diff --git a/go/base/context.go b/go/base/context.go index 52568ba..6e28cb0 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -5,8 +5,17 @@ package base -import () +import ( + "fmt" + "strings" + "sync/atomic" + "time" + "github.com/github/gh-osc/go/mysql" + "github.com/github/gh-osc/go/sql" +) + +// RowsEstimateMethod is the type of row number estimation type RowsEstimateMethod string const ( @@ -15,18 +24,41 @@ const ( CountRowsEstimate = "CountRowsEstimate" ) +const ( + maxRetries = 10 +) + +// MigrationContext has the general, global state of migration. It is used by +// all components throughout the migration process. type MigrationContext struct { - DatabaseName string - OriginalTableName string - GhostTableName string - AlterStatement string - TableEngine string - CountTableRows bool - RowsEstimate int64 - UsedRowsEstimateMethod RowsEstimateMethod - ChunkSize int - OriginalBinlogFormat string - OriginalBinlogRowImage string + DatabaseName string + OriginalTableName string + AlterStatement string + TableEngine string + CountTableRows bool + RowsEstimate int64 + UsedRowsEstimateMethod RowsEstimateMethod + ChunkSize int64 + OriginalBinlogFormat string + OriginalBinlogRowImage string + AllowedRunningOnMaster bool + InspectorConnectionConfig *mysql.ConnectionConfig + MasterConnectionConfig *mysql.ConnectionConfig + MigrationRangeMinValues *sql.ColumnValues + MigrationRangeMaxValues *sql.ColumnValues + Iteration int64 + MigrationIterationRangeMinValues *sql.ColumnValues + MigrationIterationRangeMaxValues *sql.ColumnValues + UniqueKey *sql.UniqueKey + StartTime time.Time + RowCopyStartTime time.Time + CurrentLag int64 + MaxLagMillisecondsThrottleThreshold int64 + ThrottleFlagFile string + TotalRowsCopied int64 + + IsThrottled func() bool + CanStopStreaming func() bool } var context *MigrationContext @@ -37,15 +69,75 @@ func init() { func newMigrationContext() *MigrationContext { return &MigrationContext{ - ChunkSize: 1000, + ChunkSize: 1000, + InspectorConnectionConfig: mysql.NewConnectionConfig(), + MasterConnectionConfig: mysql.NewConnectionConfig(), + MaxLagMillisecondsThrottleThreshold: 1000, } } +// GetMigrationContext func GetMigrationContext() *MigrationContext { return context } -// RequiresBinlogFormatChange +// GetGhostTableName generates the name of ghost table, based on original table name +func (this *MigrationContext) GetGhostTableName() string { + return fmt.Sprintf("_%s_New", this.OriginalTableName) +} + +// GetChangelogTableName generates the name of changelog table, based on original table name +func (this *MigrationContext) GetChangelogTableName() string { + return fmt.Sprintf("_%s_OSC", this.OriginalTableName) +} + +// RequiresBinlogFormatChange is `true` when the original binlog format isn't `ROW` func (this *MigrationContext) RequiresBinlogFormatChange() bool { return this.OriginalBinlogFormat != "ROW" } + +// IsRunningOnMaster is `true` when the app connects directly to the master (typically +// it should be executed on replica and infer the master) +func (this *MigrationContext) IsRunningOnMaster() bool { + return this.InspectorConnectionConfig.Equals(this.MasterConnectionConfig) +} + +// HasMigrationRange tells us whether there's a range to iterate for copying rows. +// It will be `false` if the table is initially empty +func (this *MigrationContext) HasMigrationRange() bool { + return this.MigrationRangeMinValues != nil && this.MigrationRangeMaxValues != nil +} + +func (this *MigrationContext) MaxRetries() int { + return maxRetries +} + +func (this *MigrationContext) IsTransactionalTable() bool { + switch strings.ToLower(this.TableEngine) { + case "innodb": + { + return true + } + case "tokudb": + { + return true + } + } + return false +} + +// ElapsedTime returns time since very beginning of the process +func (this *MigrationContext) ElapsedTime() time.Duration { + return time.Now().Sub(this.StartTime) +} + +// ElapsedRowCopyTime returns time since starting to copy chunks of rows +func (this *MigrationContext) ElapsedRowCopyTime() time.Duration { + return time.Now().Sub(this.RowCopyStartTime) +} + +// GetTotalRowsCopied returns the accurate number of rows being copied (affected) +// This is not exactly the same as the rows being iterated via chunks, but potentially close enough +func (this *MigrationContext) GetTotalRowsCopied() int64 { + return atomic.LoadInt64(&this.TotalRowsCopied) +} diff --git a/go/binlog/binlog_dml_event.go b/go/binlog/binlog_dml_event.go new file mode 100644 index 0000000..069b0a5 --- /dev/null +++ b/go/binlog/binlog_dml_event.go @@ -0,0 +1,66 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package binlog + +import ( + "fmt" + "github.com/github/gh-osc/go/sql" + "strings" +) + +type EventDML string + +const ( + NotDML EventDML = "NoDML" + InsertDML = "Insert" + UpdateDML = "Update" + DeleteDML = "Delete" +) + +func ToEventDML(description string) EventDML { + // description can be a statement (`UPDATE my_table ...`) or a RBR event name (`UpdateRowsEventV2`) + description = strings.TrimSpace(strings.Split(description, " ")[0]) + switch strings.ToLower(description) { + case "insert": + return InsertDML + case "update": + return UpdateDML + case "delete": + return DeleteDML + } + if strings.HasPrefix(description, "WriteRows") { + return InsertDML + } + if strings.HasPrefix(description, "UpdateRows") { + return UpdateDML + } + if strings.HasPrefix(description, "DeleteRows") { + return DeleteDML + } + return NotDML +} + +// BinlogDMLEvent is a binary log rows (DML) event entry, with data +type BinlogDMLEvent struct { + DatabaseName string + TableName string + DML EventDML + WhereColumnValues *sql.ColumnValues + NewColumnValues *sql.ColumnValues +} + +func NewBinlogDMLEvent(databaseName, tableName string, dml EventDML) *BinlogDMLEvent { + event := &BinlogDMLEvent{ + DatabaseName: databaseName, + TableName: tableName, + DML: dml, + } + return event +} + +func (this *BinlogDMLEvent) String() string { + return fmt.Sprintf("[%+v on %s:%s]", this.DML, this.DatabaseName, this.TableName) +} diff --git a/go/binlog/binlog_entry.go b/go/binlog/binlog_entry.go index 0610c70..961c530 100644 --- a/go/binlog/binlog_entry.go +++ b/go/binlog/binlog_entry.go @@ -5,27 +5,43 @@ package binlog +import ( + "fmt" + "github.com/github/gh-osc/go/mysql" +) + // BinlogEntry describes an entry in the binary log type BinlogEntry struct { - LogPos uint64 - EndLogPos uint64 - StatementType string // INSERT, UPDATE, DELETE - DatabaseName string - TableName string - PositionalColumns map[uint64]interface{} + Coordinates mysql.BinlogCoordinates + EndLogPos uint64 + + DmlEvent *BinlogDMLEvent } // NewBinlogEntry creates an empty, ready to go BinlogEntry object -func NewBinlogEntry() *BinlogEntry { - binlogEntry := &BinlogEntry{} - binlogEntry.PositionalColumns = make(map[uint64]interface{}) +func NewBinlogEntry(logFile string, logPos uint64) *BinlogEntry { + binlogEntry := &BinlogEntry{ + Coordinates: mysql.BinlogCoordinates{LogFile: logFile, LogPos: int64(logPos)}, + } + return binlogEntry +} + +// NewBinlogEntry creates an empty, ready to go BinlogEntry object +func NewBinlogEntryAt(coordinates mysql.BinlogCoordinates) *BinlogEntry { + binlogEntry := &BinlogEntry{ + Coordinates: coordinates, + } return binlogEntry } // Duplicate creates and returns a new binlog entry, with some of the attributes pre-assigned func (this *BinlogEntry) Duplicate() *BinlogEntry { - binlogEntry := NewBinlogEntry() - binlogEntry.LogPos = this.LogPos + binlogEntry := NewBinlogEntry(this.Coordinates.LogFile, uint64(this.Coordinates.LogPos)) binlogEntry.EndLogPos = this.EndLogPos return binlogEntry } + +// Duplicate creates and returns a new binlog entry, with some of the attributes pre-assigned +func (this *BinlogEntry) String() string { + return fmt.Sprintf("[BinlogEntry at %+v; dml:%+v]", this.Coordinates, this.DmlEvent) +} diff --git a/go/binlog/binlog_reader.go b/go/binlog/binlog_reader.go index 3392460..6d07320 100644 --- a/go/binlog/binlog_reader.go +++ b/go/binlog/binlog_reader.go @@ -8,5 +8,5 @@ package binlog // BinlogReader is a general interface whose implementations can choose their methods of reading // a binary log file and parsing it into binlog entries type BinlogReader interface { - ReadEntries(logFile string, startPos uint64, stopPos uint64) (entries [](*BinlogEntry), err error) + StreamEvents(canStopStreaming func() bool, entriesChannel chan<- *BinlogEntry) error } diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index 3e6db9d..4a0b586 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -7,11 +7,10 @@ package binlog import ( "fmt" - "os" - "reflect" - "strings" "github.com/github/gh-osc/go/mysql" + "github.com/github/gh-osc/go/sql" + "github.com/outbrain/golib/log" gomysql "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/replication" @@ -24,18 +23,24 @@ const ( ) type GoMySQLReader struct { - connectionConfig *mysql.ConnectionConfig - binlogSyncer *replication.BinlogSyncer + connectionConfig *mysql.ConnectionConfig + binlogSyncer *replication.BinlogSyncer + binlogStreamer *replication.BinlogStreamer + tableMap map[uint64]string + currentCoordinates mysql.BinlogCoordinates } func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *GoMySQLReader, err error) { binlogReader = &GoMySQLReader{ - connectionConfig: connectionConfig, + connectionConfig: connectionConfig, + tableMap: make(map[uint64]string), + currentCoordinates: mysql.BinlogCoordinates{}, + binlogStreamer: nil, } binlogReader.binlogSyncer = replication.NewBinlogSyncer(serverId, "mysql") // Register slave, the MySQL master is at 127.0.0.1:3306, with user root and an empty password - err = binlogReader.binlogSyncer.RegisterSlave(connectionConfig.Hostname, uint16(connectionConfig.Port), connectionConfig.User, connectionConfig.Password) + err = binlogReader.binlogSyncer.RegisterSlave(connectionConfig.Key.Hostname, uint16(connectionConfig.Key.Port), connectionConfig.User, connectionConfig.Password) if err != nil { return binlogReader, err } @@ -43,57 +48,75 @@ func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *G return binlogReader, err } -func (this *GoMySQLReader) isDMLEvent(event *replication.BinlogEvent) bool { - eventType := event.Header.EventType.String() - if strings.HasPrefix(eventType, "WriteRows") { - return true - } - if strings.HasPrefix(eventType, "UpdateRows") { - return true - } - if strings.HasPrefix(eventType, "DeleteRows") { - return true - } - return false -} - -// ReadEntries will read binlog entries from parsed text output of `mysqlbinlog` utility -func (this *GoMySQLReader) ReadEntries(logFile string, startPos uint64, stopPos uint64) (entries [](*BinlogEntry), err error) { +// ConnectBinlogStreamer +func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordinates) (err error) { + this.currentCoordinates = coordinates // Start sync with sepcified binlog file and position - streamer, err := this.binlogSyncer.StartSync(gomysql.Position{logFile, uint32(startPos)}) - if err != nil { - return entries, err - } + this.binlogStreamer, err = this.binlogSyncer.StartSync(gomysql.Position{coordinates.LogFile, uint32(coordinates.LogPos)}) - for { - ev, err := streamer.GetEvent() - if err != nil { - return entries, err - } - if rowsEvent, ok := ev.Event.(*replication.RowsEvent); ok { - if true { - fmt.Println(ev.Header.EventType) - fmt.Println(len(rowsEvent.Rows)) - - for _, rows := range rowsEvent.Rows { - for j, d := range rows { - if _, ok := d.([]byte); ok { - fmt.Print(fmt.Sprintf("yesbin %d:%q, %+v\n", j, d, reflect.TypeOf(d))) - } else { - fmt.Print(fmt.Sprintf("notbin %d:%#v, %+v\n", j, d, reflect.TypeOf(d))) - } - } - fmt.Println("---") - } - } else { - ev.Dump(os.Stdout) - } - // TODO : convert to entries - // need to parse multi-row entries - // insert & delete are just one row per db orw - // update: where-row_>values-row, repeating - } - } - log.Debugf("done") - return entries, err + return err +} + +// StreamEvents +func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesChannel chan<- *BinlogEntry) error { + for { + if canStopStreaming() { + break + } + ev, err := this.binlogStreamer.GetEvent() + if err != nil { + return err + } + this.currentCoordinates.LogPos = int64(ev.Header.LogPos) + if rotateEvent, ok := ev.Event.(*replication.RotateEvent); ok { + this.currentCoordinates.LogFile = string(rotateEvent.NextLogName) + log.Infof("rotate to next log name: %s", rotateEvent.NextLogName) + } else if tableMapEvent, ok := ev.Event.(*replication.TableMapEvent); ok { + // Actually not being used, since Table is available in RowsEvent. + // Keeping this here in case I'm wrong about this. Sometime in the near + // future I should remove this. + this.tableMap[tableMapEvent.TableID] = string(tableMapEvent.Table) + } else if rowsEvent, ok := ev.Event.(*replication.RowsEvent); ok { + dml := ToEventDML(ev.Header.EventType.String()) + if dml == NotDML { + return fmt.Errorf("Unknown DML type: %s", ev.Header.EventType.String()) + } + for i, row := range rowsEvent.Rows { + if dml == UpdateDML && i%2 == 1 { + // An update has two rows (WHERE+SET) + // We do both at the same time + continue + } + binlogEntry := NewBinlogEntryAt(this.currentCoordinates) + binlogEntry.DmlEvent = NewBinlogDMLEvent( + string(rowsEvent.Table.Schema), + string(rowsEvent.Table.Table), + dml, + ) + switch dml { + case InsertDML: + { + binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(row) + } + case UpdateDML: + { + binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row) + binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(rowsEvent.Rows[i+1]) + } + case DeleteDML: + { + binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row) + } + } + // The channel will do the throttling. Whoever is reding from the channel + // decides whether action is taken sycnhronously (meaning we wait before + // next iteration) or asynchronously (we keep pushing more events) + // In reality, reads will be synchronous + entriesChannel <- binlogEntry + } + } + } + log.Debugf("done streaming events") + + return nil } diff --git a/go/binlog/mysqlbinlog_reader.go b/go/binlog/mysqlbinlog_reader.go index 9ba02ae..87e408a 100644 --- a/go/binlog/mysqlbinlog_reader.go +++ b/go/binlog/mysqlbinlog_reader.go @@ -12,7 +12,7 @@ import ( "path" "regexp" "strconv" - "strings" + // "strings" "github.com/github/gh-osc/go/os" "github.com/outbrain/golib/log" @@ -78,7 +78,7 @@ func (this *MySQLBinlogReader) ReadEntries(logFile string, startPos uint64, stop return entries, log.Errore(err) } - chunkEntries, err := parseEntries(bufio.NewScanner(bytes.NewReader(entriesBytes))) + chunkEntries, err := parseEntries(bufio.NewScanner(bytes.NewReader(entriesBytes)), logFile) if err != nil { return entries, log.Errore(err) } @@ -103,41 +103,38 @@ func searchForStartPosOrStatement(scanner *bufio.Scanner, binlogEntry *BinlogEnt return InvalidState, binlogEntry, fmt.Errorf("Expected startLogPos %+v to equal previous endLogPos %+v", startLogPos, previousEndLogPos) } nextBinlogEntry = binlogEntry - if binlogEntry.LogPos != 0 && binlogEntry.StatementType != "" { + if binlogEntry.Coordinates.LogPos != 0 && binlogEntry.DmlEvent != nil { // Current entry is already a true entry, with startpos and with statement - nextBinlogEntry = NewBinlogEntry() + nextBinlogEntry = NewBinlogEntry(binlogEntry.Coordinates.LogFile, startLogPos) } - - nextBinlogEntry.LogPos = startLogPos return ExpectEndLogPosState, nextBinlogEntry, nil } onStatementEntry := func(submatch []string) (BinlogEntryState, *BinlogEntry, error) { nextBinlogEntry = binlogEntry - if binlogEntry.LogPos != 0 && binlogEntry.StatementType != "" { + if binlogEntry.Coordinates.LogPos != 0 && binlogEntry.DmlEvent != nil { // Current entry is already a true entry, with startpos and with statement nextBinlogEntry = binlogEntry.Duplicate() } - - nextBinlogEntry.StatementType = strings.Split(submatch[1], " ")[0] - nextBinlogEntry.DatabaseName = submatch[2] - nextBinlogEntry.TableName = submatch[3] + nextBinlogEntry.DmlEvent = NewBinlogDMLEvent(submatch[2], submatch[3], ToEventDML(submatch[1])) return ExpectTokenState, nextBinlogEntry, nil } - onPositionalColumn := func(submatch []string) (BinlogEntryState, *BinlogEntry, error) { - columnIndex, _ := strconv.ParseUint(submatch[1], 10, 64) - if _, found := binlogEntry.PositionalColumns[columnIndex]; found { - return InvalidState, binlogEntry, fmt.Errorf("Positional column %+v found more than once in %+v, statement=%+v", columnIndex, binlogEntry.LogPos, binlogEntry.StatementType) - } - columnValue := submatch[2] - columnValue = strings.TrimPrefix(columnValue, "'") - columnValue = strings.TrimSuffix(columnValue, "'") - binlogEntry.PositionalColumns[columnIndex] = columnValue + // Defuncting the following: - return SearchForStartPosOrStatementState, binlogEntry, nil - } + // onPositionalColumn := func(submatch []string) (BinlogEntryState, *BinlogEntry, error) { + // columnIndex, _ := strconv.ParseUint(submatch[1], 10, 64) + // if _, found := binlogEntry.PositionalColumns[columnIndex]; found { + // return InvalidState, binlogEntry, fmt.Errorf("Positional column %+v found more than once in %+v, statement=%+v", columnIndex, binlogEntry.LogPos, binlogEntry.DmlEvent.DML) + // } + // columnValue := submatch[2] + // columnValue = strings.TrimPrefix(columnValue, "'") + // columnValue = strings.TrimSuffix(columnValue, "'") + // binlogEntry.PositionalColumns[columnIndex] = columnValue + // + // return SearchForStartPosOrStatementState, binlogEntry, nil + // } line := scanner.Text() if submatch := startEntryRegexp.FindStringSubmatch(line); len(submatch) > 1 { @@ -150,7 +147,7 @@ func searchForStartPosOrStatement(scanner *bufio.Scanner, binlogEntry *BinlogEnt return onStatementEntry(submatch) } if submatch := positionalColumnRegexp.FindStringSubmatch(line); len(submatch) > 1 { - return onPositionalColumn(submatch) + // Defuncting return onPositionalColumn(submatch) } // Haven't found a match return SearchForStartPosOrStatementState, binlogEntry, nil @@ -165,7 +162,7 @@ func expectEndLogPos(scanner *bufio.Scanner, binlogEntry *BinlogEntry) (nextStat binlogEntry.EndLogPos, _ = strconv.ParseUint(submatch[1], 10, 64) return SearchForStartPosOrStatementState, nil } - return InvalidState, fmt.Errorf("Expected to find end_log_pos following pos %+v", binlogEntry.LogPos) + return InvalidState, fmt.Errorf("Expected to find end_log_pos following pos %+v", binlogEntry.Coordinates.LogPos) } // automaton step: a not-strictly-required but good-to-have-around validation that @@ -175,26 +172,26 @@ func expectToken(scanner *bufio.Scanner, binlogEntry *BinlogEntry) (nextState Bi if submatch := tokenRegxp.FindStringSubmatch(line); len(submatch) > 1 { return SearchForStartPosOrStatementState, nil } - return InvalidState, fmt.Errorf("Expected to find token following pos %+v", binlogEntry.LogPos) + return InvalidState, fmt.Errorf("Expected to find token following pos %+v", binlogEntry.Coordinates.LogPos) } // parseEntries will parse output of `mysqlbinlog --verbose --base64-output=DECODE-ROWS` // It issues an automaton / state machine to do its thang. -func parseEntries(scanner *bufio.Scanner) (entries [](*BinlogEntry), err error) { - binlogEntry := NewBinlogEntry() +func parseEntries(scanner *bufio.Scanner, logFile string) (entries [](*BinlogEntry), err error) { + binlogEntry := NewBinlogEntry(logFile, 0) var state BinlogEntryState = SearchForStartPosOrStatementState var endLogPos uint64 appendBinlogEntry := func() { - if binlogEntry.LogPos == 0 { + if binlogEntry.Coordinates.LogPos == 0 { return } - if binlogEntry.StatementType == "" { + if binlogEntry.DmlEvent == nil { return } entries = append(entries, binlogEntry) log.Debugf("entry: %+v", *binlogEntry) - fmt.Println(fmt.Sprintf("%s `%s`.`%s`", binlogEntry.StatementType, binlogEntry.DatabaseName, binlogEntry.TableName)) + fmt.Println(fmt.Sprintf("%s `%s`.`%s`", binlogEntry.DmlEvent.DML, binlogEntry.DmlEvent.DatabaseName, binlogEntry.DmlEvent.TableName)) } for scanner.Scan() { switch state { diff --git a/go/cmd/gh-osc/main.go b/go/cmd/gh-osc/main.go index 2cba72a..de7f976 100644 --- a/go/cmd/gh-osc/main.go +++ b/go/cmd/gh-osc/main.go @@ -11,15 +11,12 @@ import ( "os" "github.com/github/gh-osc/go/base" - "github.com/github/gh-osc/go/binlog" "github.com/github/gh-osc/go/logic" - "github.com/github/gh-osc/go/mysql" "github.com/outbrain/golib/log" ) // main is the application's entry point. It will either spawn a CLI or HTTP itnerfaces. func main() { - var connectionConfig mysql.ConnectionConfig migrationContext := base.GetMigrationContext() // mysqlBasedir := flag.String("mysql-basedir", "", "the --basedir config for MySQL (auto-detected if not given)") @@ -27,15 +24,19 @@ func main() { internalExperiment := flag.Bool("internal-experiment", false, "issue an internal experiment") binlogFile := flag.String("binlog-file", "", "Name of binary log file") - flag.StringVar(&connectionConfig.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)") - flag.IntVar(&connectionConfig.Port, "port", 3306, "MySQL port (preferably a replica, not the master)") - flag.StringVar(&connectionConfig.User, "user", "root", "MySQL user") - flag.StringVar(&connectionConfig.Password, "password", "", "MySQL password") + flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)") + flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)") + flag.StringVar(&migrationContext.InspectorConnectionConfig.User, "user", "root", "MySQL user") + flag.StringVar(&migrationContext.InspectorConnectionConfig.Password, "password", "", "MySQL password") flag.StringVar(&migrationContext.DatabaseName, "database", "", "database name (mandatory)") flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)") flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory)") flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)") + flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica") + + flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration") + flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists") quiet := flag.Bool("quiet", false, "quiet") verbose := flag.Bool("verbose", false, "verbose") @@ -78,19 +79,20 @@ func main() { log.Info("starting gh-osc") if *internalExperiment { - log.Debug("starting experiment") - var binlogReader binlog.BinlogReader - var err error + log.Debug("starting experiment with %+v", *binlogFile) //binlogReader = binlog.NewMySQLBinlogReader(*mysqlBasedir, *mysqlDatadir) - binlogReader, err = binlog.NewGoMySQLReader(&connectionConfig) - if err != nil { - log.Fatale(err) - } - binlogReader.ReadEntries(*binlogFile, 0, 0) - return + // binlogReader, err := binlog.NewGoMySQLReader(migrationContext.InspectorConnectionConfig) + // if err != nil { + // log.Fatale(err) + // } + // if err := binlogReader.ConnectBinlogStreamer(mysql.BinlogCoordinates{LogFile: *binlogFile, LogPos: 0}); err != nil { + // log.Fatale(err) + // } + // binlogReader.StreamEvents(func() bool { return false }) + // return } - migrator := logic.NewMigrator(&connectionConfig) + migrator := logic.NewMigrator() err := migrator.Migrate() if err != nil { log.Fatale(err) diff --git a/go/logic/applier.go b/go/logic/applier.go new file mode 100644 index 0000000..b4f63a9 --- /dev/null +++ b/go/logic/applier.go @@ -0,0 +1,414 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package logic + +import ( + gosql "database/sql" + "fmt" + "sync/atomic" + "time" + + "github.com/github/gh-osc/go/base" + "github.com/github/gh-osc/go/mysql" + "github.com/github/gh-osc/go/sql" + + "github.com/outbrain/golib/log" + "github.com/outbrain/golib/sqlutils" +) + +const ( + heartbeatIntervalSeconds = 1 +) + +// Applier reads data from the read-MySQL-server (typically a replica, but can be the master) +// It is used for gaining initial status and structure, and later also follow up on progress and changelog +type Applier struct { + connectionConfig *mysql.ConnectionConfig + db *gosql.DB + migrationContext *base.MigrationContext +} + +func NewApplier() *Applier { + return &Applier{ + connectionConfig: base.GetMigrationContext().MasterConnectionConfig, + migrationContext: base.GetMigrationContext(), + } +} + +func (this *Applier) InitDBConnections() (err error) { + ApplierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) + if this.db, _, err = sqlutils.GetDB(ApplierUri); err != nil { + return err + } + if err := this.validateConnection(); err != nil { + return err + } + return nil +} + +// validateConnection issues a simple can-connect to MySQL +func (this *Applier) validateConnection() error { + query := `select @@global.port` + var port int + if err := this.db.QueryRow(query).Scan(&port); err != nil { + return err + } + if port != this.connectionConfig.Key.Port { + return fmt.Errorf("Unexpected database port reported: %+v", port) + } + log.Infof("connection validated on %+v", this.connectionConfig.Key) + return nil +} + +// CreateGhostTable creates the ghost table on the master +func (this *Applier) CreateGhostTable() error { + query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetGhostTableName()), + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.OriginalTableName), + ) + log.Infof("Creating ghost table %s.%s", + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetGhostTableName()), + ) + if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { + return err + } + log.Infof("Ghost table created") + return nil +} + +// CreateGhostTable creates the ghost table on the master +func (this *Applier) AlterGhost() error { + query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetGhostTableName()), + this.migrationContext.AlterStatement, + ) + log.Infof("Altering ghost table %s.%s", + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetGhostTableName()), + ) + log.Debugf("ALTER statement: %s", query) + if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { + return err + } + log.Infof("Ghost table altered") + return nil +} + +// CreateChangelogTable creates the changelog table on the master +func (this *Applier) CreateChangelogTable() error { + query := fmt.Sprintf(`create /* gh-osc */ table %s.%s ( + id int auto_increment, + last_update timestamp not null DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + hint varchar(64) charset ascii not null, + value varchar(64) charset ascii not null, + primary key(id), + unique key hint_uidx(hint) + ) auto_increment=2 + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + log.Infof("Creating changelog table %s.%s", + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { + return err + } + log.Infof("Changelog table created") + return nil +} + +// DropChangelogTable drops the changelog table on the master +func (this *Applier) DropChangelogTable() error { + query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + log.Infof("Droppping changelog table %s.%s", + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { + return err + } + log.Infof("Changelog table dropped") + return nil +} + +// WriteChangelog writes a value to the changelog table. +// It returns the hint as given, for convenience +func (this *Applier) WriteChangelog(hint, value string) (string, error) { + query := fmt.Sprintf(` + insert /* gh-osc */ into %s.%s + (id, hint, value) + values + (NULL, ?, ?) + on duplicate key update + last_update=NOW(), + value=VALUES(value) + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + _, err := sqlutils.Exec(this.db, query, hint, value) + return hint, err +} + +// InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. +// This is done asynchronously +func (this *Applier) InitiateHeartbeat() { + go func() { + numSuccessiveFailures := 0 + query := fmt.Sprintf(` + insert /* gh-osc */ into %s.%s + (id, hint, value) + values + (1, 'heartbeat', ?) + on duplicate key update + last_update=NOW(), + value=VALUES(value) + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + injectHeartbeat := func() error { + if _, err := sqlutils.ExecNoPrepare(this.db, query, time.Now().Format(time.RFC3339)); err != nil { + numSuccessiveFailures++ + if numSuccessiveFailures > this.migrationContext.MaxRetries() { + return log.Errore(err) + } + } else { + numSuccessiveFailures = 0 + } + return nil + } + injectHeartbeat() + + heartbeatTick := time.Tick(time.Duration(heartbeatIntervalSeconds) * time.Second) + for range heartbeatTick { + // Generally speaking, we would issue a goroutine, but I'd actually rather + // have this blocked rather than spam the master in the event something + // goes wrong + if err := injectHeartbeat(); err != nil { + return + } + } + }() +} + +// ReadMigrationMinValues +func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { + log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) + query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + if err != nil { + return err + } + rows, err := this.db.Query(query) + if err != nil { + return err + } + for rows.Next() { + this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(len(uniqueKey.Columns)) + if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil { + return err + } + } + log.Infof("Migration min values: [%s]", this.migrationContext.MigrationRangeMinValues) + return err +} + +// ReadMigrationMinValues +func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { + log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) + query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + if err != nil { + return err + } + rows, err := this.db.Query(query) + if err != nil { + return err + } + for rows.Next() { + this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(len(uniqueKey.Columns)) + if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil { + return err + } + } + log.Infof("Migration max values: [%s]", this.migrationContext.MigrationRangeMaxValues) + return err +} + +func (this *Applier) ReadMigrationRangeValues() error { + if err := this.ReadMigrationMinValues(this.migrationContext.UniqueKey); err != nil { + return err + } + if err := this.ReadMigrationMaxValues(this.migrationContext.UniqueKey); err != nil { + return err + } + return nil +} + +// __unused_IterationIsComplete lets us know when the copy-iteration phase is complete, i.e. +// we've exhausted all rows +func (this *Applier) __unused_IterationIsComplete() (bool, error) { + if !this.migrationContext.HasMigrationRange() { + return false, nil + } + if this.migrationContext.MigrationIterationRangeMinValues == nil { + return false, nil + } + args := sqlutils.Args() + compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) + if err != nil { + return false, err + } + args = append(args, explodedArgs...) + compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) + if err != nil { + return false, err + } + args = append(args, explodedArgs...) + query := fmt.Sprintf(` + select /* gh-osc IterationIsComplete */ 1 + from %s.%s + where (%s) and (%s) + limit 1 + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.OriginalTableName), + compareWithIterationRangeStart, + compareWithRangeEnd, + ) + + moreRowsFound := false + err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { + moreRowsFound = true + return nil + }, args...) + if err != nil { + return false, err + } + return !moreRowsFound, nil +} + +// CalculateNextIterationRangeEndValues reads the next-iteration-range-end unique key values, +// which will be used for copying the next chunk of rows. Ir returns "false" if there is +// no further chunk to work through, i.e. we're past the last chunk and are done with +// itrating the range (and this done with copying row chunks) +func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange bool, err error) { + this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationIterationRangeMaxValues + if this.migrationContext.MigrationIterationRangeMinValues == nil { + this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationRangeMinValues + } + query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( + this.migrationContext.DatabaseName, + this.migrationContext.OriginalTableName, + this.migrationContext.UniqueKey.Columns, + this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), + this.migrationContext.MigrationRangeMaxValues.AbstractValues(), + this.migrationContext.ChunkSize, + fmt.Sprintf("iteration:%d", this.migrationContext.Iteration), + ) + if err != nil { + return hasFurtherRange, err + } + rows, err := this.db.Query(query, explodedArgs...) + if err != nil { + return hasFurtherRange, err + } + iterationRangeMaxValues := sql.NewColumnValues(len(this.migrationContext.UniqueKey.Columns)) + for rows.Next() { + if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil { + return hasFurtherRange, err + } + hasFurtherRange = true + } + if !hasFurtherRange { + log.Debugf("Iteration complete: cannot find iteration end") + return hasFurtherRange, nil + } + this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues + log.Debugf( + "column values: [%s]..[%s]; iteration: %d; chunk-size: %d", + this.migrationContext.MigrationIterationRangeMinValues, + this.migrationContext.MigrationIterationRangeMaxValues, + this.migrationContext.Iteration, + this.migrationContext.ChunkSize, + ) + return hasFurtherRange, nil +} + +func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected int64, duration time.Duration, err error) { + startTime := time.Now() + chunkSize = atomic.LoadInt64(&this.migrationContext.ChunkSize) + + query, explodedArgs, err := sql.BuildRangeInsertPreparedQuery( + this.migrationContext.DatabaseName, + this.migrationContext.OriginalTableName, + this.migrationContext.GetGhostTableName(), + this.migrationContext.UniqueKey.Columns, + this.migrationContext.UniqueKey.Name, + this.migrationContext.UniqueKey.Columns, + this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), + this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), + this.migrationContext.Iteration == 0, + this.migrationContext.IsTransactionalTable(), + ) + if err != nil { + return chunkSize, rowsAffected, duration, err + } + sqlResult, err := sqlutils.Exec(this.db, query, explodedArgs...) + if err != nil { + return chunkSize, rowsAffected, duration, err + } + rowsAffected, _ = sqlResult.RowsAffected() + duration = time.Now().Sub(startTime) + this.WriteChangelog( + fmt.Sprintf("copy iteration %d", this.migrationContext.Iteration), + fmt.Sprintf("chunk: %d; affected: %d; duration: %d", chunkSize, rowsAffected, duration), + ) + log.Debugf( + "Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d", + this.migrationContext.MigrationIterationRangeMinValues, + this.migrationContext.MigrationIterationRangeMaxValues, + this.migrationContext.Iteration, + chunkSize) + return chunkSize, rowsAffected, duration, nil +} + +// LockTables +func (this *Applier) LockTables() error { + query := fmt.Sprintf(`lock /* gh-osc */ tables %s.%s write, %s.%s write, %s.%s write`, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.OriginalTableName), + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetGhostTableName()), + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + log.Infof("Locking tables") + if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { + return err + } + log.Infof("Tables locked") + return nil +} + +// UnlockTables +func (this *Applier) UnlockTables() error { + query := `unlock /* gh-osc */ tables` + log.Infof("Unlocking tables") + if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { + return err + } + log.Infof("Tables unlocked") + return nil +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 8839849..31f74ec 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -26,15 +26,15 @@ type Inspector struct { migrationContext *base.MigrationContext } -func NewInspector(connectionConfig *mysql.ConnectionConfig) *Inspector { +func NewInspector() *Inspector { return &Inspector{ - connectionConfig: connectionConfig, + connectionConfig: base.GetMigrationContext().InspectorConnectionConfig, migrationContext: base.GetMigrationContext(), } } func (this *Inspector) InitDBConnections() (err error) { - inspectorUri := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", this.connectionConfig.User, this.connectionConfig.Password, this.connectionConfig.Hostname, this.connectionConfig.Port, this.migrationContext.DatabaseName) + inspectorUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) if this.db, _, err = sqlutils.GetDB(inspectorUri); err != nil { return err } @@ -47,9 +47,16 @@ func (this *Inspector) InitDBConnections() (err error) { if err := this.validateBinlogs(); err != nil { return err } + return nil +} + +func (this *Inspector) ValidateOriginalTable() (err error) { if err := this.validateTable(); err != nil { return err } + if err := this.validateTableForeignKeys(); err != nil { + return err + } if this.migrationContext.CountTableRows { if err := this.countTableRows(); err != nil { return err @@ -59,32 +66,31 @@ func (this *Inspector) InitDBConnections() (err error) { return err } } - return nil } -func (this *Inspector) InspectTables() (err error) { - uniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName) +func (this *Inspector) InspectOriginalTable() (uniqueKeys [](*sql.UniqueKey), err error) { + uniqueKeys, err = this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName) if err != nil { - return err + return uniqueKeys, err } if len(uniqueKeys) == 0 { - return fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") + return uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") } - return nil + return uniqueKeys, err } // validateConnection issues a simple can-connect to MySQL func (this *Inspector) validateConnection() error { - query := `select @@port` + query := `select @@global.port` var port int if err := this.db.QueryRow(query).Scan(&port); err != nil { return err } - if port != this.connectionConfig.Port { + if port != this.connectionConfig.Key.Port { return fmt.Errorf("Unexpected database port reported: %+v", port) } - log.Infof("connection validated on port %+v", port) + log.Infof("connection validated on %+v", this.connectionConfig.Key) return nil } @@ -116,7 +122,7 @@ func (this *Inspector) validateGrants() error { return nil }) if err != nil { - return log.Errore(err) + return err } if foundAll { @@ -130,7 +136,7 @@ func (this *Inspector) validateGrants() error { return log.Errorf("User has insufficient privileges for migration.") } -// validateConnection issues a simple can-connect to MySQL +// validateBinlogs checks that binary log configuration is good to go func (this *Inspector) validateBinlogs() error { query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format` var hasBinaryLogs, logSlaveUpdates bool @@ -138,10 +144,10 @@ func (this *Inspector) validateBinlogs() error { return err } if !hasBinaryLogs { - return fmt.Errorf("%s:%d must have binary logs enabled", this.connectionConfig.Hostname, this.connectionConfig.Port) + return fmt.Errorf("%s:%d must have binary logs enabled", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) } if !logSlaveUpdates { - return fmt.Errorf("%s:%d must have log_slave_updates enabled", this.connectionConfig.Hostname, this.connectionConfig.Port) + return fmt.Errorf("%s:%d must have log_slave_updates enabled", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) } if this.migrationContext.RequiresBinlogFormatChange() { query := fmt.Sprintf(`show /* gh-osc */ slave hosts`) @@ -151,12 +157,12 @@ func (this *Inspector) validateBinlogs() error { return nil }) if err != nil { - return log.Errore(err) + return err } if countReplicas > 0 { - return fmt.Errorf("%s:%d has %s binlog_format, but I'm too scared to change it to ROW because it has replicas. Bailing out", this.connectionConfig.Hostname, this.connectionConfig.Port, this.migrationContext.OriginalBinlogFormat) + return fmt.Errorf("%s:%d has %s binlog_format, but I'm too scared to change it to ROW because it has replicas. Bailing out", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogFormat) } - log.Infof("%s:%d has %s binlog_format. I will change it to ROW for the duration of this migration.", this.connectionConfig.Hostname, this.connectionConfig.Port, this.migrationContext.OriginalBinlogFormat) + log.Infof("%s:%d has %s binlog_format. I will change it to ROW for the duration of this migration.", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogFormat) } query = `select @@global.binlog_row_image` if err := this.db.QueryRow(query).Scan(&this.migrationContext.OriginalBinlogRowImage); err != nil { @@ -164,7 +170,7 @@ func (this *Inspector) validateBinlogs() error { this.migrationContext.OriginalBinlogRowImage = "" } - log.Infof("binary logs validated on %s:%d", this.connectionConfig.Hostname, this.connectionConfig.Port) + log.Infof("binary logs validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) return nil } @@ -185,7 +191,7 @@ func (this *Inspector) validateTable() error { return nil }) if err != nil { - return log.Errore(err) + return err } if !tableFound { return log.Errorf("Cannot find table %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) @@ -195,6 +201,37 @@ func (this *Inspector) validateTable() error { return nil } +func (this *Inspector) validateTableForeignKeys() error { + query := ` + SELECT COUNT(*) AS num_foreign_keys + FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE + WHERE + REFERENCED_TABLE_NAME IS NOT NULL + AND ((TABLE_SCHEMA=? AND TABLE_NAME=?) + OR (REFERENCED_TABLE_SCHEMA=? AND REFERENCED_TABLE_NAME=?) + ) + ` + numForeignKeys := 0 + err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { + numForeignKeys = rowMap.GetInt("num_foreign_keys") + + return nil + }, + this.migrationContext.DatabaseName, + this.migrationContext.OriginalTableName, + this.migrationContext.DatabaseName, + this.migrationContext.OriginalTableName, + ) + if err != nil { + return err + } + if numForeignKeys > 0 { + return log.Errorf("Found %d foreign keys on %s.%s. Foreign keys are not supported. Bailing out", numForeignKeys, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + } + log.Debugf("Validated no foreign keys exist on table") + return nil +} + func (this *Inspector) estimateTableRowsViaExplain() error { query := fmt.Sprintf(`explain select /* gh-osc */ * from %s.%s where 1=1`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) @@ -207,7 +244,7 @@ func (this *Inspector) estimateTableRowsViaExplain() error { return nil }) if err != nil { - return log.Errore(err) + return err } if !outputFound { return log.Errorf("Cannot run EXPLAIN on %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) @@ -227,6 +264,29 @@ func (this *Inspector) countTableRows() error { return nil } +func (this *Inspector) getTableColumns(databaseName, tableName string) (columns sql.ColumnList, err error) { + query := fmt.Sprintf(` + show columns from %s.%s + `, + sql.EscapeName(databaseName), + sql.EscapeName(tableName), + ) + err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { + columns = append(columns, rowMap.GetString("Field")) + return nil + }) + if err != nil { + return columns, err + } + if len(columns) == 0 { + return columns, log.Errorf("Found 0 columns on %s.%s. Bailing out", + sql.EscapeName(databaseName), + sql.EscapeName(tableName), + ) + } + return columns, nil +} + // getCandidateUniqueKeys investigates a table and returns the list of unique keys // candidate for chunking func (this *Inspector) getCandidateUniqueKeys(tableName string) (uniqueKeys [](*sql.UniqueKey), err error) { @@ -308,7 +368,7 @@ func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err if err != nil { return uniqueKeys, err } - ghostUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.GhostTableName) + ghostUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.GetGhostTableName()) if err != nil { return uniqueKeys, err } @@ -323,3 +383,44 @@ func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err } return uniqueKeys, nil } + +func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { + visitedKeys := mysql.NewInstanceKeyMap() + return getMasterConnectionConfigSafe(this.connectionConfig, this.migrationContext.DatabaseName, visitedKeys) +} + +func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, databaseName string, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) { + log.Debugf("Looking for master on %+v", connectionConfig.Key) + + currentUri := connectionConfig.GetDBUri(databaseName) + db, _, err := sqlutils.GetDB(currentUri) + if err != nil { + return nil, err + } + + hasMaster := false + masterConfig = connectionConfig.Duplicate() + err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { + masterKey := mysql.InstanceKey{ + Hostname: rowMap.GetString("Master_Host"), + Port: rowMap.GetInt("Master_Port"), + } + if masterKey.IsValid() { + masterConfig.Key = masterKey + hasMaster = true + } + return nil + }) + if err != nil { + return nil, err + } + if hasMaster { + log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) + if visitedKeys.HasKey(masterConfig.Key) { + return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key) + } + visitedKeys.AddKey(masterConfig.Key) + return getMasterConnectionConfigSafe(masterConfig, databaseName, visitedKeys) + } + return masterConfig, nil +} diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 40adab9..451cb96 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -6,28 +6,350 @@ package logic import ( - "github.com/github/gh-osc/go/mysql" + "fmt" + "os" + "sync/atomic" + "time" + + "github.com/github/gh-osc/go/base" + "github.com/github/gh-osc/go/binlog" + + "github.com/outbrain/golib/log" +) + +type ChangelogState string + +const ( + TablesInPlace ChangelogState = "TablesInPlace" + AllEventsUpToLockProcessed = "AllEventsUpToLockProcessed" +) + +type tableWriteFunc func() error + +const ( + applyEventsQueueBuffer = 100 ) // Migrator is the main schema migration flow manager. type Migrator struct { - connectionConfig *mysql.ConnectionConfig inspector *Inspector + applier *Applier + eventsStreamer *EventsStreamer + migrationContext *base.MigrationContext + + tablesInPlace chan bool + rowCopyComplete chan bool + allEventsUpToLockProcessed chan bool + + // copyRowsQueue should not be buffered; if buffered some non-damaging but + // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete + copyRowsQueue chan tableWriteFunc + applyEventsQueue chan tableWriteFunc } -func NewMigrator(connectionConfig *mysql.ConnectionConfig) *Migrator { - return &Migrator{ - connectionConfig: connectionConfig, - inspector: NewInspector(connectionConfig), +func NewMigrator() *Migrator { + migrator := &Migrator{ + migrationContext: base.GetMigrationContext(), + tablesInPlace: make(chan bool), + rowCopyComplete: make(chan bool), + allEventsUpToLockProcessed: make(chan bool), + + copyRowsQueue: make(chan tableWriteFunc), + applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), } + migrator.migrationContext.IsThrottled = func() bool { + return migrator.shouldThrottle() + } + return migrator } -func (this *Migrator) Migrate() error { +func (this *Migrator) shouldThrottle() bool { + lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) + + shouldThrottle := false + if time.Duration(lag) > time.Duration(this.migrationContext.MaxLagMillisecondsThrottleThreshold)*time.Millisecond { + shouldThrottle = true + } else if this.migrationContext.ThrottleFlagFile != "" { + if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil { + //Throttle file defined and exists! + shouldThrottle = true + } + } + return shouldThrottle +} + +func (this *Migrator) canStopStreaming() bool { + return false +} + +func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { + // Hey, I created the changlog table, I know the type of columns it has! + if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" { + return + } + changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3)) + switch changelogState { + case TablesInPlace: + { + this.tablesInPlace <- true + } + case AllEventsUpToLockProcessed: + { + this.allEventsUpToLockProcessed <- true + } + default: + { + return fmt.Errorf("Unknown changelog state: %+v", changelogState) + } + } + log.Debugf("---- - - - - - state %+v", changelogState) + return nil +} + +func (this *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { + if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "heartbeat" { + return nil + } + value := dmlEvent.NewColumnValues.StringColumn(3) + heartbeatTime, err := time.Parse(time.RFC3339, value) + if err != nil { + return log.Errore(err) + } + lag := time.Now().Sub(heartbeatTime) + + atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) + + return nil +} + +func (this *Migrator) Migrate() (err error) { + this.migrationContext.StartTime = time.Now() + + this.inspector = NewInspector() if err := this.inspector.InitDBConnections(); err != nil { return err } - if err := this.inspector.InspectTables(); err != nil { + if err := this.inspector.ValidateOriginalTable(); err != nil { return err } + uniqueKeys, err := this.inspector.InspectOriginalTable() + if err != nil { + return err + } + // So far so good, table is accessible and valid. + if this.migrationContext.MasterConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { + return err + } + if this.migrationContext.IsRunningOnMaster() && !this.migrationContext.AllowedRunningOnMaster { + return fmt.Errorf("It seems like this migration attempt to run directly on master. Preferably it would be executed on a replica (and this reduces load from the master). To proceed please provide --allow-on-master") + } + log.Infof("Master found to be %+v", this.migrationContext.MasterConnectionConfig.Key) + + if err := this.initiateStreaming(); err != nil { + return err + } + if err := this.initiateApplier(); err != nil { + return err + } + + log.Debugf("Waiting for tables to be in place") + <-this.tablesInPlace + log.Debugf("Tables are in place") + // Yay! We now know the Ghost and Changelog tables are good to examine! + // When running on replica, this means the replica has those tables. When running + // on master this is always true, of course, and yet it also implies this knowledge + // is in the binlogs. + + this.migrationContext.UniqueKey = uniqueKeys[0] // TODO. Need to wait on replica till the ghost table exists and get shared keys + if err := this.applier.ReadMigrationRangeValues(); err != nil { + return err + } + go this.initiateStatus() + go this.executeWriteFuncs() + go this.iterateChunks() + + log.Debugf("Operating until row copy is complete") + <-this.rowCopyComplete + log.Debugf("Row copy complete") + this.printStatus() + + throttleMigration( + this.migrationContext, + func() { + log.Debugf("throttling before LOCK TABLES") + }, + nil, + func() { + log.Debugf("done throttling") + }, + ) + // TODO retries!! + this.applier.LockTables() + this.applier.WriteChangelog("state", string(AllEventsUpToLockProcessed)) + log.Debugf("Waiting for events up to lock") + <-this.allEventsUpToLockProcessed + log.Debugf("Done waiting for events up to lock") + // TODO retries!! + this.applier.UnlockTables() + + return nil +} + +func (this *Migrator) initiateStatus() error { + this.printStatus() + statusTick := time.Tick(1 * time.Second) + for range statusTick { + go this.printStatus() + } + + return nil +} + +func (this *Migrator) printStatus() { + elapsedTime := this.migrationContext.ElapsedTime() + elapsedSeconds := int64(elapsedTime.Seconds()) + totalRowsCopied := this.migrationContext.GetTotalRowsCopied() + rowsEstimate := this.migrationContext.RowsEstimate + progressPct := 100.0 * float64(totalRowsCopied) / float64(rowsEstimate) + + shouldPrintStatus := false + if elapsedSeconds <= 60 { + shouldPrintStatus = true + } else if progressPct >= 99.0 { + shouldPrintStatus = true + } else if progressPct >= 95.0 { + shouldPrintStatus = (elapsedSeconds%5 == 0) + } else if elapsedSeconds <= 120 { + shouldPrintStatus = (elapsedSeconds%5 == 0) + } else { + shouldPrintStatus = (elapsedSeconds%30 == 0) + } + if !shouldPrintStatus { + return + } + + status := fmt.Sprintf("Copy: %d/%d %.1f%% Backlog: %d/%d Elapsed: %+v(copy), %+v(total) ETA: N/A", + totalRowsCopied, rowsEstimate, progressPct, + len(this.applyEventsQueue), cap(this.applyEventsQueue), + this.migrationContext.ElapsedRowCopyTime(), elapsedTime) + fmt.Println(status) +} + +func (this *Migrator) initiateStreaming() error { + this.eventsStreamer = NewEventsStreamer() + if err := this.eventsStreamer.InitDBConnections(); err != nil { + return err + } + this.eventsStreamer.AddListener( + false, + this.migrationContext.DatabaseName, + this.migrationContext.GetChangelogTableName(), + func(dmlEvent *binlog.BinlogDMLEvent) error { + return this.onChangelogStateEvent(dmlEvent) + }, + ) + this.eventsStreamer.AddListener( + false, + this.migrationContext.DatabaseName, + this.migrationContext.GetChangelogTableName(), + func(dmlEvent *binlog.BinlogDMLEvent) error { + return this.onChangelogHeartbeatEvent(dmlEvent) + }, + ) + go func() { + log.Debugf("Beginning streaming") + this.eventsStreamer.StreamEvents(func() bool { return this.canStopStreaming() }) + }() + return nil +} + +func (this *Migrator) initiateApplier() error { + this.applier = NewApplier() + if err := this.applier.InitDBConnections(); err != nil { + return err + } + if err := this.applier.CreateGhostTable(); err != nil { + log.Errorf("Unable to create ghost table, see further error details. Perhaps a previous migration failed without dropping the table? Bailing out") + return err + } + if err := this.applier.AlterGhost(); err != nil { + log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out") + return err + } + if err := this.applier.CreateChangelogTable(); err != nil { + log.Errorf("Unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") + return err + } + + this.applier.WriteChangelog("state", string(TablesInPlace)) + this.applier.InitiateHeartbeat() + return nil +} + +func (this *Migrator) iterateChunks() error { + this.migrationContext.RowCopyStartTime = time.Now() + terminateRowIteration := func(err error) error { + this.rowCopyComplete <- true + return log.Errore(err) + } + for { + copyRowsFunc := func() error { + hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues() + if err != nil { + return terminateRowIteration(err) + } + if !hasFurtherRange { + return terminateRowIteration(nil) + } + _, rowsAffected, _, err := this.applier.ApplyIterationInsertQuery() + if err != nil { + return terminateRowIteration(err) + } + atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected) + this.migrationContext.Iteration++ + return nil + } + this.copyRowsQueue <- copyRowsFunc + } + return nil +} + +func (this *Migrator) executeWriteFuncs() error { + for { + throttleMigration( + this.migrationContext, + func() { + log.Debugf("throttling writes") + }, + nil, + func() { + log.Debugf("done throttling writes") + }, + ) + // We give higher priority to event processing, then secondary priority to + // rowcopy + select { + case applyEventFunc := <-this.applyEventsQueue: + { + retryOperation(applyEventFunc, this.migrationContext.MaxRetries()) + } + default: + { + select { + case copyRowsFunc := <-this.copyRowsQueue: + { + retryOperation(copyRowsFunc, this.migrationContext.MaxRetries()) + } + default: + { + // Hmmmmm... nothing in the queue; no events, but also no row copy. + // This is possible upon load. Let's just sleep it over. + log.Debugf("Getting nothing in the write queue. Sleeping...") + time.Sleep(time.Second) + } + } + } + } + } return nil } diff --git a/go/logic/streamer.go b/go/logic/streamer.go new file mode 100644 index 0000000..176ddf2 --- /dev/null +++ b/go/logic/streamer.go @@ -0,0 +1,162 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package logic + +import ( + gosql "database/sql" + "fmt" + "strings" + + "github.com/github/gh-osc/go/base" + "github.com/github/gh-osc/go/binlog" + "github.com/github/gh-osc/go/mysql" + + "github.com/outbrain/golib/log" + "github.com/outbrain/golib/sqlutils" +) + +type BinlogEventListener struct { + async bool + databaseName string + tableName string + onDmlEvent func(event *binlog.BinlogDMLEvent) error +} + +const ( + EventsChannelBufferSize = 1 +) + +// EventsStreamer reads data from binary logs and streams it on. It acts as a publisher, +// and interested parties may subscribe for per-table events. +type EventsStreamer struct { + connectionConfig *mysql.ConnectionConfig + db *gosql.DB + migrationContext *base.MigrationContext + nextBinlogCoordinates *mysql.BinlogCoordinates + listeners [](*BinlogEventListener) + eventsChannel chan *binlog.BinlogEntry + binlogReader binlog.BinlogReader +} + +func NewEventsStreamer() *EventsStreamer { + return &EventsStreamer{ + connectionConfig: base.GetMigrationContext().InspectorConnectionConfig, + migrationContext: base.GetMigrationContext(), + listeners: [](*BinlogEventListener){}, + eventsChannel: make(chan *binlog.BinlogEntry, EventsChannelBufferSize), + } +} + +func (this *EventsStreamer) AddListener( + async bool, databaseName string, tableName string, onDmlEvent func(event *binlog.BinlogDMLEvent) error) (err error) { + if databaseName == "" { + return fmt.Errorf("Empty database name in AddListener") + } + if tableName == "" { + return fmt.Errorf("Empty table name in AddListener") + } + listener := &BinlogEventListener{ + async: async, + databaseName: databaseName, + tableName: tableName, + onDmlEvent: onDmlEvent, + } + this.listeners = append(this.listeners, listener) + return nil +} + +func (this *EventsStreamer) notifyListeners(binlogEvent *binlog.BinlogDMLEvent) { + for _, listener := range this.listeners { + if strings.ToLower(listener.databaseName) != strings.ToLower(binlogEvent.DatabaseName) { + continue + } + if strings.ToLower(listener.tableName) != strings.ToLower(binlogEvent.TableName) { + continue + } + onDmlEvent := listener.onDmlEvent + if listener.async { + go func() { + onDmlEvent(binlogEvent) + }() + } else { + onDmlEvent(binlogEvent) + } + } +} + +func (this *EventsStreamer) InitDBConnections() (err error) { + EventsStreamerUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) + if this.db, _, err = sqlutils.GetDB(EventsStreamerUri); err != nil { + return err + } + if err := this.validateConnection(); err != nil { + return err + } + if err := this.readCurrentBinlogCoordinates(); err != nil { + return err + } + goMySQLReader, err := binlog.NewGoMySQLReader(this.migrationContext.InspectorConnectionConfig) + if err != nil { + return err + } + if err := goMySQLReader.ConnectBinlogStreamer(*this.nextBinlogCoordinates); err != nil { + return err + } + this.binlogReader = goMySQLReader + + return nil +} + +// validateConnection issues a simple can-connect to MySQL +func (this *EventsStreamer) validateConnection() error { + query := `select @@global.port` + var port int + if err := this.db.QueryRow(query).Scan(&port); err != nil { + return err + } + if port != this.connectionConfig.Key.Port { + return fmt.Errorf("Unexpected database port reported: %+v", port) + } + log.Infof("connection validated on %+v", this.connectionConfig.Key) + return nil +} + +// validateGrants verifies the user by which we're executing has necessary grants +// to do its thang. +func (this *EventsStreamer) readCurrentBinlogCoordinates() error { + query := `show /* gh-osc readCurrentBinlogCoordinates */ master status` + foundMasterStatus := false + err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { + this.nextBinlogCoordinates = &mysql.BinlogCoordinates{ + LogFile: m.GetString("File"), + LogPos: m.GetInt64("Position"), + } + foundMasterStatus = true + + return nil + }) + if err != nil { + return err + } + if !foundMasterStatus { + return fmt.Errorf("Got no results from SHOW MASTER STATUS. Bailing out") + } + log.Debugf("Streamer binlog coordinates: %+v", *this.nextBinlogCoordinates) + return nil +} + +// StreamEvents will begin streaming events. It will be blocking, so should be +// executed by a goroutine +func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { + go func() { + for binlogEntry := range this.eventsChannel { + if binlogEntry.DmlEvent != nil { + this.notifyListeners(binlogEntry.DmlEvent) + } + } + }() + return this.binlogReader.StreamEvents(canStopStreaming, this.eventsChannel) +} diff --git a/go/mysql/binlog.go b/go/mysql/binlog.go new file mode 100644 index 0000000..42cf933 --- /dev/null +++ b/go/mysql/binlog.go @@ -0,0 +1,162 @@ +/* + Copyright 2015 Shlomi Noach, courtesy Booking.com + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package mysql + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +var detachPattern *regexp.Regexp + +func init() { + detachPattern, _ = regexp.Compile(`//([^/:]+):([\d]+)`) // e.g. `//binlog.01234:567890` +} + +type BinlogType int + +const ( + BinaryLog BinlogType = iota + RelayLog +) + +// BinlogCoordinates described binary log coordinates in the form of log file & log position. +type BinlogCoordinates struct { + LogFile string + LogPos int64 + Type BinlogType +} + +// ParseInstanceKey will parse an InstanceKey from a string representation such as 127.0.0.1:3306 +func ParseBinlogCoordinates(logFileLogPos string) (*BinlogCoordinates, error) { + tokens := strings.SplitN(logFileLogPos, ":", 2) + if len(tokens) != 2 { + return nil, fmt.Errorf("ParseBinlogCoordinates: Cannot parse BinlogCoordinates from %s. Expected format is file:pos", logFileLogPos) + } + + if logPos, err := strconv.ParseInt(tokens[1], 10, 0); err != nil { + return nil, fmt.Errorf("ParseBinlogCoordinates: invalid pos: %s", tokens[1]) + } else { + return &BinlogCoordinates{LogFile: tokens[0], LogPos: logPos}, nil + } +} + +// DisplayString returns a user-friendly string representation of these coordinates +func (this *BinlogCoordinates) DisplayString() string { + return fmt.Sprintf("%s:%d", this.LogFile, this.LogPos) +} + +// String returns a user-friendly string representation of these coordinates +func (this BinlogCoordinates) String() string { + return this.DisplayString() +} + +// Equals tests equality of this corrdinate and another one. +func (this *BinlogCoordinates) Equals(other *BinlogCoordinates) bool { + if other == nil { + return false + } + return this.LogFile == other.LogFile && this.LogPos == other.LogPos && this.Type == other.Type +} + +// IsEmpty returns true if the log file is empty, unnamed +func (this *BinlogCoordinates) IsEmpty() bool { + return this.LogFile == "" +} + +// SmallerThan returns true if this coordinate is strictly smaller than the other. +func (this *BinlogCoordinates) SmallerThan(other *BinlogCoordinates) bool { + if this.LogFile < other.LogFile { + return true + } + if this.LogFile == other.LogFile && this.LogPos < other.LogPos { + return true + } + return false +} + +// SmallerThanOrEquals returns true if this coordinate is the same or equal to the other one. +// We do NOT compare the type so we can not use this.Equals() +func (this *BinlogCoordinates) SmallerThanOrEquals(other *BinlogCoordinates) bool { + if this.SmallerThan(other) { + return true + } + return this.LogFile == other.LogFile && this.LogPos == other.LogPos // No Type comparison +} + +// FileSmallerThan returns true if this coordinate's file is strictly smaller than the other's. +func (this *BinlogCoordinates) FileSmallerThan(other *BinlogCoordinates) bool { + return this.LogFile < other.LogFile +} + +// FileNumberDistance returns the numeric distance between this corrdinate's file number and the other's. +// Effectively it means "how many roatets/FLUSHes would make these coordinates's file reach the other's" +func (this *BinlogCoordinates) FileNumberDistance(other *BinlogCoordinates) int { + thisNumber, _ := this.FileNumber() + otherNumber, _ := other.FileNumber() + return otherNumber - thisNumber +} + +// FileNumber returns the numeric value of the file, and the length in characters representing the number in the filename. +// Example: FileNumber() of mysqld.log.000789 is (789, 6) +func (this *BinlogCoordinates) FileNumber() (int, int) { + tokens := strings.Split(this.LogFile, ".") + numPart := tokens[len(tokens)-1] + numLen := len(numPart) + fileNum, err := strconv.Atoi(numPart) + if err != nil { + return 0, 0 + } + return fileNum, numLen +} + +// PreviousFileCoordinatesBy guesses the filename of the previous binlog/relaylog, by given offset (number of files back) +func (this *BinlogCoordinates) PreviousFileCoordinatesBy(offset int) (BinlogCoordinates, error) { + result := BinlogCoordinates{LogPos: 0, Type: this.Type} + + fileNum, numLen := this.FileNumber() + if fileNum == 0 { + return result, errors.New("Log file number is zero, cannot detect previous file") + } + newNumStr := fmt.Sprintf("%d", (fileNum - offset)) + newNumStr = strings.Repeat("0", numLen-len(newNumStr)) + newNumStr + + tokens := strings.Split(this.LogFile, ".") + tokens[len(tokens)-1] = newNumStr + result.LogFile = strings.Join(tokens, ".") + return result, nil +} + +// PreviousFileCoordinates guesses the filename of the previous binlog/relaylog +func (this *BinlogCoordinates) PreviousFileCoordinates() (BinlogCoordinates, error) { + return this.PreviousFileCoordinatesBy(1) +} + +// PreviousFileCoordinates guesses the filename of the previous binlog/relaylog +func (this *BinlogCoordinates) NextFileCoordinates() (BinlogCoordinates, error) { + result := BinlogCoordinates{LogPos: 0, Type: this.Type} + + fileNum, numLen := this.FileNumber() + newNumStr := fmt.Sprintf("%d", (fileNum + 1)) + newNumStr = strings.Repeat("0", numLen-len(newNumStr)) + newNumStr + + tokens := strings.Split(this.LogFile, ".") + tokens[len(tokens)-1] = newNumStr + result.LogFile = strings.Join(tokens, ".") + return result, nil +} + +// FileSmallerThan returns true if this coordinate's file is strictly smaller than the other's. +func (this *BinlogCoordinates) DetachedCoordinates() (isDetached bool, detachedLogFile string, detachedLogPos string) { + detachedCoordinatesSubmatch := detachPattern.FindStringSubmatch(this.LogFile) + if len(detachedCoordinatesSubmatch) == 0 { + return false, "", "" + } + return true, detachedCoordinatesSubmatch[1], detachedCoordinatesSubmatch[2] +} diff --git a/go/mysql/connection.go b/go/mysql/connection.go index fc4115b..4d15241 100644 --- a/go/mysql/connection.go +++ b/go/mysql/connection.go @@ -5,10 +5,44 @@ package mysql +import ( + "fmt" +) + // ConnectionConfig is the minimal configuration required to connect to a MySQL server type ConnectionConfig struct { - Hostname string - Port int + Key InstanceKey User string Password string } + +func NewConnectionConfig() *ConnectionConfig { + config := &ConnectionConfig{ + Key: InstanceKey{}, + } + return config +} + +func (this *ConnectionConfig) Duplicate() *ConnectionConfig { + config := &ConnectionConfig{ + Key: InstanceKey{ + Hostname: this.Key.Hostname, + Port: this.Key.Port, + }, + User: this.User, + Password: this.Password, + } + return config +} + +func (this *ConnectionConfig) String() string { + return fmt.Sprintf("%s, user=%s", this.Key.DisplayString(), this.User) +} + +func (this *ConnectionConfig) Equals(other *ConnectionConfig) bool { + return this.Key.Equals(&other.Key) +} + +func (this *ConnectionConfig) GetDBUri(databaseName string) string { + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", this.User, this.Password, this.Key.Hostname, this.Key.Port, databaseName) +} diff --git a/go/mysql/instance_key.go b/go/mysql/instance_key.go new file mode 100644 index 0000000..06f6bd5 --- /dev/null +++ b/go/mysql/instance_key.go @@ -0,0 +1,115 @@ +/* + Copyright 2015 Shlomi Noach, courtesy Booking.com + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package mysql + +import ( + "fmt" + "strconv" + "strings" +) + +const ( + DefaultInstancePort = 3306 +) + +// InstanceKey is an instance indicator, identifued by hostname and port +type InstanceKey struct { + Hostname string + Port int +} + +const detachHint = "//" + +// ParseInstanceKey will parse an InstanceKey from a string representation such as 127.0.0.1:3306 +func NewRawInstanceKey(hostPort string) (*InstanceKey, error) { + tokens := strings.SplitN(hostPort, ":", 2) + if len(tokens) != 2 { + return nil, fmt.Errorf("Cannot parse InstanceKey from %s. Expected format is host:port", hostPort) + } + instanceKey := &InstanceKey{Hostname: tokens[0]} + var err error + if instanceKey.Port, err = strconv.Atoi(tokens[1]); err != nil { + return instanceKey, fmt.Errorf("Invalid port: %s", tokens[1]) + } + + return instanceKey, nil +} + +// ParseRawInstanceKeyLoose will parse an InstanceKey from a string representation such as 127.0.0.1:3306. +// The port part is optional; there will be no name resolve +func ParseRawInstanceKeyLoose(hostPort string) (*InstanceKey, error) { + if !strings.Contains(hostPort, ":") { + return &InstanceKey{Hostname: hostPort, Port: DefaultInstancePort}, nil + } + return NewRawInstanceKey(hostPort) +} + +// Equals tests equality between this key and another key +func (this *InstanceKey) Equals(other *InstanceKey) bool { + if other == nil { + return false + } + return this.Hostname == other.Hostname && this.Port == other.Port +} + +// SmallerThan returns true if this key is dictionary-smaller than another. +// This is used for consistent sorting/ordering; there's nothing magical about it. +func (this *InstanceKey) SmallerThan(other *InstanceKey) bool { + if this.Hostname < other.Hostname { + return true + } + if this.Hostname == other.Hostname && this.Port < other.Port { + return true + } + return false +} + +// IsDetached returns 'true' when this hostname is logically "detached" +func (this *InstanceKey) IsDetached() bool { + return strings.HasPrefix(this.Hostname, detachHint) +} + +// IsValid uses simple heuristics to see whether this key represents an actual instance +func (this *InstanceKey) IsValid() bool { + if this.Hostname == "_" { + return false + } + if this.IsDetached() { + return false + } + return len(this.Hostname) > 0 && this.Port > 0 +} + +// DetachedKey returns an instance key whose hostname is detahced: invalid, but recoverable +func (this *InstanceKey) DetachedKey() *InstanceKey { + if this.IsDetached() { + return this + } + return &InstanceKey{Hostname: fmt.Sprintf("%s%s", detachHint, this.Hostname), Port: this.Port} +} + +// ReattachedKey returns an instance key whose hostname is detahced: invalid, but recoverable +func (this *InstanceKey) ReattachedKey() *InstanceKey { + if !this.IsDetached() { + return this + } + return &InstanceKey{Hostname: this.Hostname[len(detachHint):], Port: this.Port} +} + +// StringCode returns an official string representation of this key +func (this *InstanceKey) StringCode() string { + return fmt.Sprintf("%s:%d", this.Hostname, this.Port) +} + +// DisplayString returns a user-friendly string representation of this key +func (this *InstanceKey) DisplayString() string { + return this.StringCode() +} + +// String returns a user-friendly string representation of this key +func (this InstanceKey) String() string { + return this.StringCode() +} diff --git a/go/mysql/instance_key_map.go b/go/mysql/instance_key_map.go new file mode 100644 index 0000000..85e3b6a --- /dev/null +++ b/go/mysql/instance_key_map.go @@ -0,0 +1,95 @@ +/* + Copyright 2015 Shlomi Noach, courtesy Booking.com + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package mysql + +import ( + "encoding/json" + "strings" +) + +// InstanceKeyMap is a convenience struct for listing InstanceKey-s +type InstanceKeyMap map[InstanceKey]bool + +func NewInstanceKeyMap() *InstanceKeyMap { + return &InstanceKeyMap{} +} + +// AddKey adds a single key to this map +func (this *InstanceKeyMap) AddKey(key InstanceKey) { + (*this)[key] = true +} + +// AddKeys adds all given keys to this map +func (this *InstanceKeyMap) AddKeys(keys []InstanceKey) { + for _, key := range keys { + this.AddKey(key) + } +} + +// HasKey checks if given key is within the map +func (this *InstanceKeyMap) HasKey(key InstanceKey) bool { + _, ok := (*this)[key] + return ok +} + +// GetInstanceKeys returns keys in this map in the form of an array +func (this *InstanceKeyMap) GetInstanceKeys() []InstanceKey { + res := []InstanceKey{} + for key := range *this { + res = append(res, key) + } + return res +} + +// MarshalJSON will marshal this map as JSON +func (this *InstanceKeyMap) MarshalJSON() ([]byte, error) { + return json.Marshal(this.GetInstanceKeys()) +} + +// ToJSON will marshal this map as JSON +func (this *InstanceKeyMap) ToJSON() (string, error) { + bytes, err := this.MarshalJSON() + return string(bytes), err +} + +// ToJSONString will marshal this map as JSON +func (this *InstanceKeyMap) ToJSONString() string { + s, _ := this.ToJSON() + return s +} + +// ToCommaDelimitedList will export this map in comma delimited format +func (this *InstanceKeyMap) ToCommaDelimitedList() string { + keyDisplays := []string{} + for key := range *this { + keyDisplays = append(keyDisplays, key.DisplayString()) + } + return strings.Join(keyDisplays, ",") +} + +// ReadJson unmarshalls a json into this map +func (this *InstanceKeyMap) ReadJson(jsonString string) error { + var keys []InstanceKey + err := json.Unmarshal([]byte(jsonString), &keys) + if err != nil { + return err + } + this.AddKeys(keys) + return err +} + +// ReadJson unmarshalls a json into this map +func (this *InstanceKeyMap) ReadCommaDelimitedList(list string) error { + tokens := strings.Split(list, ",") + for _, token := range tokens { + key, err := ParseRawInstanceKeyLoose(token) + if err != nil { + return err + } + this.AddKey(*key) + } + return nil +} diff --git a/go/sql/builder.go b/go/sql/builder.go index e5c5659..b8f0358 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -64,12 +64,15 @@ func BuildEqualsComparison(columns []string, values []string) (result string, er return result, nil } -func BuildRangeComparison(columns []string, values []string, comparisonSign ValueComparisonSign) (result string, err error) { +func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { if len(columns) == 0 { - return "", fmt.Errorf("Got 0 columns in GetRangeComparison") + return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") } if len(columns) != len(values) { - return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values)) + return "", explodedArgs, fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values)) + } + if len(columns) != len(args) { + return "", explodedArgs, fmt.Errorf("Got %d columns but %d args in GetEqualsComparison", len(columns), len(args)) } includeEquals := false if comparisonSign == LessThanOrEqualsComparisonSign { @@ -87,43 +90,47 @@ func BuildRangeComparison(columns []string, values []string, comparisonSign Valu value := values[i] rangeComparison, err := BuildValueComparison(column, value, comparisonSign) if err != nil { - return "", err + return "", explodedArgs, err } if len(columns[0:i]) > 0 { equalitiesComparison, err := BuildEqualsComparison(columns[0:i], values[0:i]) if err != nil { - return "", err + return "", explodedArgs, err } comparison := fmt.Sprintf("(%s AND %s)", equalitiesComparison, rangeComparison) comparisons = append(comparisons, comparison) + explodedArgs = append(explodedArgs, args[0:i]...) + explodedArgs = append(explodedArgs, args[i]) } else { comparisons = append(comparisons, rangeComparison) + explodedArgs = append(explodedArgs, args[i]) } } if includeEquals { comparison, err := BuildEqualsComparison(columns, values) if err != nil { - return "", nil + return "", explodedArgs, nil } comparisons = append(comparisons, comparison) + explodedArgs = append(explodedArgs, args...) } result = strings.Join(comparisons, " or ") result = fmt.Sprintf("(%s)", result) - return result, nil + return result, explodedArgs, nil } -func BuildRangePreparedComparison(columns []string, comparisonSign ValueComparisonSign) (result string, err error) { +func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { values := make([]string, len(columns), len(columns)) for i := range columns { values[i] = "?" } - return BuildRangeComparison(columns, values, comparisonSign) + return BuildRangeComparison(columns, values, args, comparisonSign) } -func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string) (string, error) { +func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { if len(sharedColumns) == 0 { - return "", fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery") + return "", explodedArgs, fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery") } databaseName = EscapeName(databaseName) originalTableName = EscapeName(originalTableName) @@ -134,50 +141,63 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin uniqueKey = EscapeName(uniqueKey) sharedColumnsListing := strings.Join(sharedColumns, ", ") - rangeStartComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, GreaterThanOrEqualsComparisonSign) - if err != nil { - return "", err + var minRangeComparisonSign ValueComparisonSign = GreaterThanComparisonSign + if includeRangeStartValues { + minRangeComparisonSign = GreaterThanOrEqualsComparisonSign } - rangeEndComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, LessThanOrEqualsComparisonSign) + rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, rangeStartArgs, minRangeComparisonSign) if err != nil { - return "", err + return "", explodedArgs, err } - query := fmt.Sprintf(` + explodedArgs = append(explodedArgs, rangeExplodedArgs...) + rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, rangeEndArgs, LessThanOrEqualsComparisonSign) + if err != nil { + return "", explodedArgs, err + } + explodedArgs = append(explodedArgs, rangeExplodedArgs...) + transactionalClause := "" + if transactionalTable { + transactionalClause = "lock in share mode" + } + result = fmt.Sprintf(` insert /* gh-osc %s.%s */ ignore into %s.%s (%s) (select %s from %s.%s force index (%s) - where (%s and %s) + where (%s and %s) %s ) `, databaseName, originalTableName, databaseName, ghostTableName, sharedColumnsListing, sharedColumnsListing, databaseName, originalTableName, uniqueKey, - rangeStartComparison, rangeEndComparison) - return query, nil + rangeStartComparison, rangeEndComparison, transactionalClause) + return result, explodedArgs, nil } -func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string) (string, error) { +func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { rangeStartValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) rangeEndValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { rangeStartValues[i] = "?" rangeEndValues[i] = "?" } - return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues) + return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) } -func BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName string, uniqueKeyColumns []string, chunkSize int) (string, error) { +func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, hint string) (result string, explodedArgs []interface{}, err error) { if len(uniqueKeyColumns) == 0 { - return "", fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery") + return "", explodedArgs, fmt.Errorf("Got 0 columns in BuildUniqueKeyRangeEndPreparedQuery") } databaseName = EscapeName(databaseName) - originalTableName = EscapeName(originalTableName) + tableName = EscapeName(tableName) - rangeStartComparison, err := BuildRangePreparedComparison(uniqueKeyColumns, GreaterThanComparisonSign) + rangeStartComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeStartArgs, GreaterThanComparisonSign) if err != nil { - return "", err + return "", explodedArgs, err } - rangeEndComparison, err := BuildRangePreparedComparison(uniqueKeyColumns, LessThanOrEqualsComparisonSign) + explodedArgs = append(explodedArgs, rangeExplodedArgs...) + rangeEndComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeEndArgs, LessThanOrEqualsComparisonSign) if err != nil { - return "", err + return "", explodedArgs, err } + explodedArgs = append(explodedArgs, rangeExplodedArgs...) + uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { @@ -185,8 +205,8 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName string, uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumns[i]) uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumns[i]) } - query := fmt.Sprintf(` - select /* gh-osc %s.%s */ %s + result = fmt.Sprintf(` + select /* gh-osc %s.%s %s */ %s from ( select %s @@ -200,11 +220,45 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName string, order by %s limit 1 - `, databaseName, originalTableName, strings.Join(uniqueKeyColumns, ", "), - strings.Join(uniqueKeyColumns, ", "), databaseName, originalTableName, + `, databaseName, tableName, hint, strings.Join(uniqueKeyColumns, ", "), + strings.Join(uniqueKeyColumns, ", "), databaseName, tableName, rangeStartComparison, rangeEndComparison, strings.Join(uniqueKeyColumnAscending, ", "), chunkSize, strings.Join(uniqueKeyColumnDescending, ", "), ) + return result, explodedArgs, nil +} + +func BuildUniqueKeyMinValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) { + return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "asc") +} + +func BuildUniqueKeyMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) { + return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "desc") +} + +func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, order string) (string, error) { + if len(uniqueKeyColumns) == 0 { + return "", fmt.Errorf("Got 0 columns in BuildUniqueKeyMinMaxValuesPreparedQuery") + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + + uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) + for i := range uniqueKeyColumns { + uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) + uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumns[i], order) + } + query := fmt.Sprintf(` + select /* gh-osc %s.%s */ %s + from + %s.%s + order by + %s + limit 1 + `, databaseName, tableName, strings.Join(uniqueKeyColumns, ", "), + databaseName, tableName, + strings.Join(uniqueKeyColumnOrder, ", "), + ) return query, nil } diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index 289cacf..b101bbf 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -8,6 +8,7 @@ package sql import ( "testing" + "reflect" "regexp" "strings" @@ -71,48 +72,60 @@ func TestBuildRangeComparison(t *testing.T) { { columns := []string{"c1"} values := []string{"@v1"} - comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign) + args := []interface{}{3} + comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanComparisonSign) test.S(t).ExpectNil(err) test.S(t).ExpectEquals(comparison, "((`c1` < @v1))") + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3})) } { columns := []string{"c1"} values := []string{"@v1"} - comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + args := []interface{}{3} + comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign) test.S(t).ExpectNil(err) test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or ((`c1` = @v1)))") + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3})) } { columns := []string{"c1", "c2"} values := []string{"@v1", "@v2"} - comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign) + args := []interface{}{3, 17} + comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanComparisonSign) test.S(t).ExpectNil(err) test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)))") + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17})) } { columns := []string{"c1", "c2"} values := []string{"@v1", "@v2"} - comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + args := []interface{}{3, 17} + comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign) test.S(t).ExpectNil(err) test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or ((`c1` = @v1) and (`c2` = @v2)))") + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17})) } { columns := []string{"c1", "c2", "c3"} values := []string{"@v1", "@v2", "@v3"} - comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + args := []interface{}{3, 17, 22} + comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign) test.S(t).ExpectNil(err) test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or (((`c1` = @v1) and (`c2` = @v2)) AND (`c3` < @v3)) or ((`c1` = @v1) and (`c2` = @v2) and (`c3` = @v3)))") + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 22, 3, 17, 22})) } { columns := []string{"c1"} values := []string{"@v1", "@v2"} - _, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + args := []interface{}{3, 17} + _, _, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign) test.S(t).ExpectNotNil(err) } { columns := []string{} values := []string{} - _, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + args := []interface{}{} + _, _, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign) test.S(t).ExpectNotNil(err) } } @@ -127,8 +140,10 @@ func TestBuildRangeInsertQuery(t *testing.T) { uniqueKeyColumns := []string{"id"} rangeStartValues := []string{"@v1s"} rangeEndValues := []string{"@v1e"} + rangeStartArgs := []interface{}{3} + rangeEndArgs := []interface{}{103} - query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -137,14 +152,17 @@ func TestBuildRangeInsertQuery(t *testing.T) { ) ` test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 103, 103})) } { uniqueKey := "name_position_uidx" uniqueKeyColumns := []string{"name", "position"} rangeStartValues := []string{"@v1s", "@v2s"} rangeEndValues := []string{"@v1e", "@v2e"} + rangeStartArgs := []interface{}{3, 17} + rangeEndArgs := []interface{}{103, 117} - query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -153,6 +171,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { ) ` test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117})) } } @@ -164,8 +183,10 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { { uniqueKey := "name_position_uidx" uniqueKeyColumns := []string{"name", "position"} + rangeStartArgs := []interface{}{3, 17} + rangeEndArgs := []interface{}{103, 117} - query, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns) + query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -174,6 +195,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { ) ` test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117})) } } @@ -183,11 +205,13 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { chunkSize := 500 { uniqueKeyColumns := []string{"name", "position"} + rangeStartArgs := []interface{}{3, 17} + rangeEndArgs := []interface{}{103, 117} - query, err := BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName, uniqueKeyColumns, chunkSize) + query, explodedArgs, err := BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, chunkSize, "test") test.S(t).ExpectNil(err) expected := ` - select /* gh-osc mydb.tbl */ name, position + select /* gh-osc mydb.tbl test */ name, position from ( select name, position @@ -203,5 +227,38 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { limit 1 ` test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 103, 103, 117, 103, 117})) + } +} + +func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) { + databaseName := "mydb" + originalTableName := "tbl" + uniqueKeyColumns := []string{"name", "position"} + { + query, err := BuildUniqueKeyMinValuesPreparedQuery(databaseName, originalTableName, uniqueKeyColumns) + test.S(t).ExpectNil(err) + expected := ` + select /* gh-osc mydb.tbl */ name, position + from + mydb.tbl + order by + name asc, position asc + limit 1 + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + } + { + query, err := BuildUniqueKeyMaxValuesPreparedQuery(databaseName, originalTableName, uniqueKeyColumns) + test.S(t).ExpectNil(err) + expected := ` + select /* gh-osc mydb.tbl */ name, position + from + mydb.tbl + order by + name desc, position desc + limit 1 + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) } } diff --git a/go/sql/types.go b/go/sql/types.go index 2646e2f..942bd37 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -35,7 +35,7 @@ type UniqueKey struct { HasNullable bool } -// IsPrimary cehcks if this unique key is primary +// IsPrimary checks if this unique key is primary func (this *UniqueKey) IsPrimary() bool { return this.Name == "PRIMARY" } @@ -43,3 +43,52 @@ func (this *UniqueKey) IsPrimary() bool { func (this *UniqueKey) String() string { return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable) } + +type ColumnValues struct { + abstractValues []interface{} + ValuesPointers []interface{} +} + +func NewColumnValues(length int) *ColumnValues { + result := &ColumnValues{ + abstractValues: make([]interface{}, length), + ValuesPointers: make([]interface{}, length), + } + for i := 0; i < length; i++ { + result.ValuesPointers[i] = &result.abstractValues[i] + } + + return result +} + +func ToColumnValues(abstractValues []interface{}) *ColumnValues { + result := &ColumnValues{ + abstractValues: abstractValues, + ValuesPointers: make([]interface{}, len(abstractValues)), + } + for i := 0; i < len(abstractValues); i++ { + result.ValuesPointers[i] = &result.abstractValues[i] + } + + return result +} + +func (this *ColumnValues) AbstractValues() []interface{} { + return this.abstractValues +} + +func (this *ColumnValues) StringColumn(index int) string { + val := this.AbstractValues()[index] + if ints, ok := val.([]uint8); ok { + return string(ints) + } + return fmt.Sprintf("%+v", val) +} + +func (this *ColumnValues) String() string { + stringValues := []string{} + for i := range this.AbstractValues() { + stringValues = append(stringValues, this.StringColumn(i)) + } + return strings.Join(stringValues, ",") +}