Merge pull request #14 from github/ongoing-initial-work

WIP, ongoing basics
This commit is contained in:
Shlomi Noach 2016-04-08 10:35:57 +02:00
commit 4652bb7728
18 changed files with 1971 additions and 210 deletions

View File

@ -5,8 +5,17 @@
package base
import ()
import (
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/github/gh-osc/go/mysql"
"github.com/github/gh-osc/go/sql"
)
// RowsEstimateMethod is the type of row number estimation
type RowsEstimateMethod string
const (
@ -15,18 +24,41 @@ const (
CountRowsEstimate = "CountRowsEstimate"
)
const (
maxRetries = 10
)
// MigrationContext has the general, global state of migration. It is used by
// all components throughout the migration process.
type MigrationContext struct {
DatabaseName string
OriginalTableName string
GhostTableName string
AlterStatement string
TableEngine string
CountTableRows bool
RowsEstimate int64
UsedRowsEstimateMethod RowsEstimateMethod
ChunkSize int
OriginalBinlogFormat string
OriginalBinlogRowImage string
DatabaseName string
OriginalTableName string
AlterStatement string
TableEngine string
CountTableRows bool
RowsEstimate int64
UsedRowsEstimateMethod RowsEstimateMethod
ChunkSize int64
OriginalBinlogFormat string
OriginalBinlogRowImage string
AllowedRunningOnMaster bool
InspectorConnectionConfig *mysql.ConnectionConfig
MasterConnectionConfig *mysql.ConnectionConfig
MigrationRangeMinValues *sql.ColumnValues
MigrationRangeMaxValues *sql.ColumnValues
Iteration int64
MigrationIterationRangeMinValues *sql.ColumnValues
MigrationIterationRangeMaxValues *sql.ColumnValues
UniqueKey *sql.UniqueKey
StartTime time.Time
RowCopyStartTime time.Time
CurrentLag int64
MaxLagMillisecondsThrottleThreshold int64
ThrottleFlagFile string
TotalRowsCopied int64
IsThrottled func() bool
CanStopStreaming func() bool
}
var context *MigrationContext
@ -37,15 +69,75 @@ func init() {
func newMigrationContext() *MigrationContext {
return &MigrationContext{
ChunkSize: 1000,
ChunkSize: 1000,
InspectorConnectionConfig: mysql.NewConnectionConfig(),
MasterConnectionConfig: mysql.NewConnectionConfig(),
MaxLagMillisecondsThrottleThreshold: 1000,
}
}
// GetMigrationContext
func GetMigrationContext() *MigrationContext {
return context
}
// RequiresBinlogFormatChange
// GetGhostTableName generates the name of ghost table, based on original table name
func (this *MigrationContext) GetGhostTableName() string {
return fmt.Sprintf("_%s_New", this.OriginalTableName)
}
// GetChangelogTableName generates the name of changelog table, based on original table name
func (this *MigrationContext) GetChangelogTableName() string {
return fmt.Sprintf("_%s_OSC", this.OriginalTableName)
}
// RequiresBinlogFormatChange is `true` when the original binlog format isn't `ROW`
func (this *MigrationContext) RequiresBinlogFormatChange() bool {
return this.OriginalBinlogFormat != "ROW"
}
// IsRunningOnMaster is `true` when the app connects directly to the master (typically
// it should be executed on replica and infer the master)
func (this *MigrationContext) IsRunningOnMaster() bool {
return this.InspectorConnectionConfig.Equals(this.MasterConnectionConfig)
}
// HasMigrationRange tells us whether there's a range to iterate for copying rows.
// It will be `false` if the table is initially empty
func (this *MigrationContext) HasMigrationRange() bool {
return this.MigrationRangeMinValues != nil && this.MigrationRangeMaxValues != nil
}
func (this *MigrationContext) MaxRetries() int {
return maxRetries
}
func (this *MigrationContext) IsTransactionalTable() bool {
switch strings.ToLower(this.TableEngine) {
case "innodb":
{
return true
}
case "tokudb":
{
return true
}
}
return false
}
// ElapsedTime returns time since very beginning of the process
func (this *MigrationContext) ElapsedTime() time.Duration {
return time.Now().Sub(this.StartTime)
}
// ElapsedRowCopyTime returns time since starting to copy chunks of rows
func (this *MigrationContext) ElapsedRowCopyTime() time.Duration {
return time.Now().Sub(this.RowCopyStartTime)
}
// GetTotalRowsCopied returns the accurate number of rows being copied (affected)
// This is not exactly the same as the rows being iterated via chunks, but potentially close enough
func (this *MigrationContext) GetTotalRowsCopied() int64 {
return atomic.LoadInt64(&this.TotalRowsCopied)
}

View File

@ -0,0 +1,66 @@
/*
Copyright 2016 GitHub Inc.
See https://github.com/github/gh-osc/blob/master/LICENSE
*/
package binlog
import (
"fmt"
"github.com/github/gh-osc/go/sql"
"strings"
)
type EventDML string
const (
NotDML EventDML = "NoDML"
InsertDML = "Insert"
UpdateDML = "Update"
DeleteDML = "Delete"
)
func ToEventDML(description string) EventDML {
// description can be a statement (`UPDATE my_table ...`) or a RBR event name (`UpdateRowsEventV2`)
description = strings.TrimSpace(strings.Split(description, " ")[0])
switch strings.ToLower(description) {
case "insert":
return InsertDML
case "update":
return UpdateDML
case "delete":
return DeleteDML
}
if strings.HasPrefix(description, "WriteRows") {
return InsertDML
}
if strings.HasPrefix(description, "UpdateRows") {
return UpdateDML
}
if strings.HasPrefix(description, "DeleteRows") {
return DeleteDML
}
return NotDML
}
// BinlogDMLEvent is a binary log rows (DML) event entry, with data
type BinlogDMLEvent struct {
DatabaseName string
TableName string
DML EventDML
WhereColumnValues *sql.ColumnValues
NewColumnValues *sql.ColumnValues
}
func NewBinlogDMLEvent(databaseName, tableName string, dml EventDML) *BinlogDMLEvent {
event := &BinlogDMLEvent{
DatabaseName: databaseName,
TableName: tableName,
DML: dml,
}
return event
}
func (this *BinlogDMLEvent) String() string {
return fmt.Sprintf("[%+v on %s:%s]", this.DML, this.DatabaseName, this.TableName)
}

View File

@ -5,27 +5,43 @@
package binlog
import (
"fmt"
"github.com/github/gh-osc/go/mysql"
)
// BinlogEntry describes an entry in the binary log
type BinlogEntry struct {
LogPos uint64
EndLogPos uint64
StatementType string // INSERT, UPDATE, DELETE
DatabaseName string
TableName string
PositionalColumns map[uint64]interface{}
Coordinates mysql.BinlogCoordinates
EndLogPos uint64
DmlEvent *BinlogDMLEvent
}
// NewBinlogEntry creates an empty, ready to go BinlogEntry object
func NewBinlogEntry() *BinlogEntry {
binlogEntry := &BinlogEntry{}
binlogEntry.PositionalColumns = make(map[uint64]interface{})
func NewBinlogEntry(logFile string, logPos uint64) *BinlogEntry {
binlogEntry := &BinlogEntry{
Coordinates: mysql.BinlogCoordinates{LogFile: logFile, LogPos: int64(logPos)},
}
return binlogEntry
}
// NewBinlogEntry creates an empty, ready to go BinlogEntry object
func NewBinlogEntryAt(coordinates mysql.BinlogCoordinates) *BinlogEntry {
binlogEntry := &BinlogEntry{
Coordinates: coordinates,
}
return binlogEntry
}
// Duplicate creates and returns a new binlog entry, with some of the attributes pre-assigned
func (this *BinlogEntry) Duplicate() *BinlogEntry {
binlogEntry := NewBinlogEntry()
binlogEntry.LogPos = this.LogPos
binlogEntry := NewBinlogEntry(this.Coordinates.LogFile, uint64(this.Coordinates.LogPos))
binlogEntry.EndLogPos = this.EndLogPos
return binlogEntry
}
// Duplicate creates and returns a new binlog entry, with some of the attributes pre-assigned
func (this *BinlogEntry) String() string {
return fmt.Sprintf("[BinlogEntry at %+v; dml:%+v]", this.Coordinates, this.DmlEvent)
}

View File

@ -8,5 +8,5 @@ package binlog
// BinlogReader is a general interface whose implementations can choose their methods of reading
// a binary log file and parsing it into binlog entries
type BinlogReader interface {
ReadEntries(logFile string, startPos uint64, stopPos uint64) (entries [](*BinlogEntry), err error)
StreamEvents(canStopStreaming func() bool, entriesChannel chan<- *BinlogEntry) error
}

View File

@ -7,11 +7,10 @@ package binlog
import (
"fmt"
"os"
"reflect"
"strings"
"github.com/github/gh-osc/go/mysql"
"github.com/github/gh-osc/go/sql"
"github.com/outbrain/golib/log"
gomysql "github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/replication"
@ -24,18 +23,24 @@ const (
)
type GoMySQLReader struct {
connectionConfig *mysql.ConnectionConfig
binlogSyncer *replication.BinlogSyncer
connectionConfig *mysql.ConnectionConfig
binlogSyncer *replication.BinlogSyncer
binlogStreamer *replication.BinlogStreamer
tableMap map[uint64]string
currentCoordinates mysql.BinlogCoordinates
}
func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *GoMySQLReader, err error) {
binlogReader = &GoMySQLReader{
connectionConfig: connectionConfig,
connectionConfig: connectionConfig,
tableMap: make(map[uint64]string),
currentCoordinates: mysql.BinlogCoordinates{},
binlogStreamer: nil,
}
binlogReader.binlogSyncer = replication.NewBinlogSyncer(serverId, "mysql")
// Register slave, the MySQL master is at 127.0.0.1:3306, with user root and an empty password
err = binlogReader.binlogSyncer.RegisterSlave(connectionConfig.Hostname, uint16(connectionConfig.Port), connectionConfig.User, connectionConfig.Password)
err = binlogReader.binlogSyncer.RegisterSlave(connectionConfig.Key.Hostname, uint16(connectionConfig.Key.Port), connectionConfig.User, connectionConfig.Password)
if err != nil {
return binlogReader, err
}
@ -43,57 +48,75 @@ func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *G
return binlogReader, err
}
func (this *GoMySQLReader) isDMLEvent(event *replication.BinlogEvent) bool {
eventType := event.Header.EventType.String()
if strings.HasPrefix(eventType, "WriteRows") {
return true
}
if strings.HasPrefix(eventType, "UpdateRows") {
return true
}
if strings.HasPrefix(eventType, "DeleteRows") {
return true
}
return false
}
// ReadEntries will read binlog entries from parsed text output of `mysqlbinlog` utility
func (this *GoMySQLReader) ReadEntries(logFile string, startPos uint64, stopPos uint64) (entries [](*BinlogEntry), err error) {
// ConnectBinlogStreamer
func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordinates) (err error) {
this.currentCoordinates = coordinates
// Start sync with sepcified binlog file and position
streamer, err := this.binlogSyncer.StartSync(gomysql.Position{logFile, uint32(startPos)})
if err != nil {
return entries, err
}
this.binlogStreamer, err = this.binlogSyncer.StartSync(gomysql.Position{coordinates.LogFile, uint32(coordinates.LogPos)})
for {
ev, err := streamer.GetEvent()
if err != nil {
return entries, err
}
if rowsEvent, ok := ev.Event.(*replication.RowsEvent); ok {
if true {
fmt.Println(ev.Header.EventType)
fmt.Println(len(rowsEvent.Rows))
for _, rows := range rowsEvent.Rows {
for j, d := range rows {
if _, ok := d.([]byte); ok {
fmt.Print(fmt.Sprintf("yesbin %d:%q, %+v\n", j, d, reflect.TypeOf(d)))
} else {
fmt.Print(fmt.Sprintf("notbin %d:%#v, %+v\n", j, d, reflect.TypeOf(d)))
}
}
fmt.Println("---")
}
} else {
ev.Dump(os.Stdout)
}
// TODO : convert to entries
// need to parse multi-row entries
// insert & delete are just one row per db orw
// update: where-row_>values-row, repeating
}
}
log.Debugf("done")
return entries, err
return err
}
// StreamEvents
func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesChannel chan<- *BinlogEntry) error {
for {
if canStopStreaming() {
break
}
ev, err := this.binlogStreamer.GetEvent()
if err != nil {
return err
}
this.currentCoordinates.LogPos = int64(ev.Header.LogPos)
if rotateEvent, ok := ev.Event.(*replication.RotateEvent); ok {
this.currentCoordinates.LogFile = string(rotateEvent.NextLogName)
log.Infof("rotate to next log name: %s", rotateEvent.NextLogName)
} else if tableMapEvent, ok := ev.Event.(*replication.TableMapEvent); ok {
// Actually not being used, since Table is available in RowsEvent.
// Keeping this here in case I'm wrong about this. Sometime in the near
// future I should remove this.
this.tableMap[tableMapEvent.TableID] = string(tableMapEvent.Table)
} else if rowsEvent, ok := ev.Event.(*replication.RowsEvent); ok {
dml := ToEventDML(ev.Header.EventType.String())
if dml == NotDML {
return fmt.Errorf("Unknown DML type: %s", ev.Header.EventType.String())
}
for i, row := range rowsEvent.Rows {
if dml == UpdateDML && i%2 == 1 {
// An update has two rows (WHERE+SET)
// We do both at the same time
continue
}
binlogEntry := NewBinlogEntryAt(this.currentCoordinates)
binlogEntry.DmlEvent = NewBinlogDMLEvent(
string(rowsEvent.Table.Schema),
string(rowsEvent.Table.Table),
dml,
)
switch dml {
case InsertDML:
{
binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(row)
}
case UpdateDML:
{
binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row)
binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(rowsEvent.Rows[i+1])
}
case DeleteDML:
{
binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row)
}
}
// The channel will do the throttling. Whoever is reding from the channel
// decides whether action is taken sycnhronously (meaning we wait before
// next iteration) or asynchronously (we keep pushing more events)
// In reality, reads will be synchronous
entriesChannel <- binlogEntry
}
}
}
log.Debugf("done streaming events")
return nil
}

View File

@ -12,7 +12,7 @@ import (
"path"
"regexp"
"strconv"
"strings"
// "strings"
"github.com/github/gh-osc/go/os"
"github.com/outbrain/golib/log"
@ -78,7 +78,7 @@ func (this *MySQLBinlogReader) ReadEntries(logFile string, startPos uint64, stop
return entries, log.Errore(err)
}
chunkEntries, err := parseEntries(bufio.NewScanner(bytes.NewReader(entriesBytes)))
chunkEntries, err := parseEntries(bufio.NewScanner(bytes.NewReader(entriesBytes)), logFile)
if err != nil {
return entries, log.Errore(err)
}
@ -103,41 +103,38 @@ func searchForStartPosOrStatement(scanner *bufio.Scanner, binlogEntry *BinlogEnt
return InvalidState, binlogEntry, fmt.Errorf("Expected startLogPos %+v to equal previous endLogPos %+v", startLogPos, previousEndLogPos)
}
nextBinlogEntry = binlogEntry
if binlogEntry.LogPos != 0 && binlogEntry.StatementType != "" {
if binlogEntry.Coordinates.LogPos != 0 && binlogEntry.DmlEvent != nil {
// Current entry is already a true entry, with startpos and with statement
nextBinlogEntry = NewBinlogEntry()
nextBinlogEntry = NewBinlogEntry(binlogEntry.Coordinates.LogFile, startLogPos)
}
nextBinlogEntry.LogPos = startLogPos
return ExpectEndLogPosState, nextBinlogEntry, nil
}
onStatementEntry := func(submatch []string) (BinlogEntryState, *BinlogEntry, error) {
nextBinlogEntry = binlogEntry
if binlogEntry.LogPos != 0 && binlogEntry.StatementType != "" {
if binlogEntry.Coordinates.LogPos != 0 && binlogEntry.DmlEvent != nil {
// Current entry is already a true entry, with startpos and with statement
nextBinlogEntry = binlogEntry.Duplicate()
}
nextBinlogEntry.StatementType = strings.Split(submatch[1], " ")[0]
nextBinlogEntry.DatabaseName = submatch[2]
nextBinlogEntry.TableName = submatch[3]
nextBinlogEntry.DmlEvent = NewBinlogDMLEvent(submatch[2], submatch[3], ToEventDML(submatch[1]))
return ExpectTokenState, nextBinlogEntry, nil
}
onPositionalColumn := func(submatch []string) (BinlogEntryState, *BinlogEntry, error) {
columnIndex, _ := strconv.ParseUint(submatch[1], 10, 64)
if _, found := binlogEntry.PositionalColumns[columnIndex]; found {
return InvalidState, binlogEntry, fmt.Errorf("Positional column %+v found more than once in %+v, statement=%+v", columnIndex, binlogEntry.LogPos, binlogEntry.StatementType)
}
columnValue := submatch[2]
columnValue = strings.TrimPrefix(columnValue, "'")
columnValue = strings.TrimSuffix(columnValue, "'")
binlogEntry.PositionalColumns[columnIndex] = columnValue
// Defuncting the following:
return SearchForStartPosOrStatementState, binlogEntry, nil
}
// onPositionalColumn := func(submatch []string) (BinlogEntryState, *BinlogEntry, error) {
// columnIndex, _ := strconv.ParseUint(submatch[1], 10, 64)
// if _, found := binlogEntry.PositionalColumns[columnIndex]; found {
// return InvalidState, binlogEntry, fmt.Errorf("Positional column %+v found more than once in %+v, statement=%+v", columnIndex, binlogEntry.LogPos, binlogEntry.DmlEvent.DML)
// }
// columnValue := submatch[2]
// columnValue = strings.TrimPrefix(columnValue, "'")
// columnValue = strings.TrimSuffix(columnValue, "'")
// binlogEntry.PositionalColumns[columnIndex] = columnValue
//
// return SearchForStartPosOrStatementState, binlogEntry, nil
// }
line := scanner.Text()
if submatch := startEntryRegexp.FindStringSubmatch(line); len(submatch) > 1 {
@ -150,7 +147,7 @@ func searchForStartPosOrStatement(scanner *bufio.Scanner, binlogEntry *BinlogEnt
return onStatementEntry(submatch)
}
if submatch := positionalColumnRegexp.FindStringSubmatch(line); len(submatch) > 1 {
return onPositionalColumn(submatch)
// Defuncting return onPositionalColumn(submatch)
}
// Haven't found a match
return SearchForStartPosOrStatementState, binlogEntry, nil
@ -165,7 +162,7 @@ func expectEndLogPos(scanner *bufio.Scanner, binlogEntry *BinlogEntry) (nextStat
binlogEntry.EndLogPos, _ = strconv.ParseUint(submatch[1], 10, 64)
return SearchForStartPosOrStatementState, nil
}
return InvalidState, fmt.Errorf("Expected to find end_log_pos following pos %+v", binlogEntry.LogPos)
return InvalidState, fmt.Errorf("Expected to find end_log_pos following pos %+v", binlogEntry.Coordinates.LogPos)
}
// automaton step: a not-strictly-required but good-to-have-around validation that
@ -175,26 +172,26 @@ func expectToken(scanner *bufio.Scanner, binlogEntry *BinlogEntry) (nextState Bi
if submatch := tokenRegxp.FindStringSubmatch(line); len(submatch) > 1 {
return SearchForStartPosOrStatementState, nil
}
return InvalidState, fmt.Errorf("Expected to find token following pos %+v", binlogEntry.LogPos)
return InvalidState, fmt.Errorf("Expected to find token following pos %+v", binlogEntry.Coordinates.LogPos)
}
// parseEntries will parse output of `mysqlbinlog --verbose --base64-output=DECODE-ROWS`
// It issues an automaton / state machine to do its thang.
func parseEntries(scanner *bufio.Scanner) (entries [](*BinlogEntry), err error) {
binlogEntry := NewBinlogEntry()
func parseEntries(scanner *bufio.Scanner, logFile string) (entries [](*BinlogEntry), err error) {
binlogEntry := NewBinlogEntry(logFile, 0)
var state BinlogEntryState = SearchForStartPosOrStatementState
var endLogPos uint64
appendBinlogEntry := func() {
if binlogEntry.LogPos == 0 {
if binlogEntry.Coordinates.LogPos == 0 {
return
}
if binlogEntry.StatementType == "" {
if binlogEntry.DmlEvent == nil {
return
}
entries = append(entries, binlogEntry)
log.Debugf("entry: %+v", *binlogEntry)
fmt.Println(fmt.Sprintf("%s `%s`.`%s`", binlogEntry.StatementType, binlogEntry.DatabaseName, binlogEntry.TableName))
fmt.Println(fmt.Sprintf("%s `%s`.`%s`", binlogEntry.DmlEvent.DML, binlogEntry.DmlEvent.DatabaseName, binlogEntry.DmlEvent.TableName))
}
for scanner.Scan() {
switch state {

View File

@ -11,15 +11,12 @@ import (
"os"
"github.com/github/gh-osc/go/base"
"github.com/github/gh-osc/go/binlog"
"github.com/github/gh-osc/go/logic"
"github.com/github/gh-osc/go/mysql"
"github.com/outbrain/golib/log"
)
// main is the application's entry point. It will either spawn a CLI or HTTP itnerfaces.
func main() {
var connectionConfig mysql.ConnectionConfig
migrationContext := base.GetMigrationContext()
// mysqlBasedir := flag.String("mysql-basedir", "", "the --basedir config for MySQL (auto-detected if not given)")
@ -27,15 +24,19 @@ func main() {
internalExperiment := flag.Bool("internal-experiment", false, "issue an internal experiment")
binlogFile := flag.String("binlog-file", "", "Name of binary log file")
flag.StringVar(&connectionConfig.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)")
flag.IntVar(&connectionConfig.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
flag.StringVar(&connectionConfig.User, "user", "root", "MySQL user")
flag.StringVar(&connectionConfig.Password, "password", "", "MySQL password")
flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)")
flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
flag.StringVar(&migrationContext.InspectorConnectionConfig.User, "user", "root", "MySQL user")
flag.StringVar(&migrationContext.InspectorConnectionConfig.Password, "password", "", "MySQL password")
flag.StringVar(&migrationContext.DatabaseName, "database", "", "database name (mandatory)")
flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)")
flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory)")
flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)")
flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica")
flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration")
flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists")
quiet := flag.Bool("quiet", false, "quiet")
verbose := flag.Bool("verbose", false, "verbose")
@ -78,19 +79,20 @@ func main() {
log.Info("starting gh-osc")
if *internalExperiment {
log.Debug("starting experiment")
var binlogReader binlog.BinlogReader
var err error
log.Debug("starting experiment with %+v", *binlogFile)
//binlogReader = binlog.NewMySQLBinlogReader(*mysqlBasedir, *mysqlDatadir)
binlogReader, err = binlog.NewGoMySQLReader(&connectionConfig)
if err != nil {
log.Fatale(err)
}
binlogReader.ReadEntries(*binlogFile, 0, 0)
return
// binlogReader, err := binlog.NewGoMySQLReader(migrationContext.InspectorConnectionConfig)
// if err != nil {
// log.Fatale(err)
// }
// if err := binlogReader.ConnectBinlogStreamer(mysql.BinlogCoordinates{LogFile: *binlogFile, LogPos: 0}); err != nil {
// log.Fatale(err)
// }
// binlogReader.StreamEvents(func() bool { return false })
// return
}
migrator := logic.NewMigrator(&connectionConfig)
migrator := logic.NewMigrator()
err := migrator.Migrate()
if err != nil {
log.Fatale(err)

414
go/logic/applier.go Normal file
View File

@ -0,0 +1,414 @@
/*
Copyright 2016 GitHub Inc.
See https://github.com/github/gh-osc/blob/master/LICENSE
*/
package logic
import (
gosql "database/sql"
"fmt"
"sync/atomic"
"time"
"github.com/github/gh-osc/go/base"
"github.com/github/gh-osc/go/mysql"
"github.com/github/gh-osc/go/sql"
"github.com/outbrain/golib/log"
"github.com/outbrain/golib/sqlutils"
)
const (
heartbeatIntervalSeconds = 1
)
// Applier reads data from the read-MySQL-server (typically a replica, but can be the master)
// It is used for gaining initial status and structure, and later also follow up on progress and changelog
type Applier struct {
connectionConfig *mysql.ConnectionConfig
db *gosql.DB
migrationContext *base.MigrationContext
}
func NewApplier() *Applier {
return &Applier{
connectionConfig: base.GetMigrationContext().MasterConnectionConfig,
migrationContext: base.GetMigrationContext(),
}
}
func (this *Applier) InitDBConnections() (err error) {
ApplierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
if this.db, _, err = sqlutils.GetDB(ApplierUri); err != nil {
return err
}
if err := this.validateConnection(); err != nil {
return err
}
return nil
}
// validateConnection issues a simple can-connect to MySQL
func (this *Applier) validateConnection() error {
query := `select @@global.port`
var port int
if err := this.db.QueryRow(query).Scan(&port); err != nil {
return err
}
if port != this.connectionConfig.Key.Port {
return fmt.Errorf("Unexpected database port reported: %+v", port)
}
log.Infof("connection validated on %+v", this.connectionConfig.Key)
return nil
}
// CreateGhostTable creates the ghost table on the master
func (this *Applier) CreateGhostTable() error {
query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetGhostTableName()),
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.OriginalTableName),
)
log.Infof("Creating ghost table %s.%s",
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetGhostTableName()),
)
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err
}
log.Infof("Ghost table created")
return nil
}
// CreateGhostTable creates the ghost table on the master
func (this *Applier) AlterGhost() error {
query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetGhostTableName()),
this.migrationContext.AlterStatement,
)
log.Infof("Altering ghost table %s.%s",
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetGhostTableName()),
)
log.Debugf("ALTER statement: %s", query)
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err
}
log.Infof("Ghost table altered")
return nil
}
// CreateChangelogTable creates the changelog table on the master
func (this *Applier) CreateChangelogTable() error {
query := fmt.Sprintf(`create /* gh-osc */ table %s.%s (
id int auto_increment,
last_update timestamp not null DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
hint varchar(64) charset ascii not null,
value varchar(64) charset ascii not null,
primary key(id),
unique key hint_uidx(hint)
) auto_increment=2
`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
log.Infof("Creating changelog table %s.%s",
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err
}
log.Infof("Changelog table created")
return nil
}
// DropChangelogTable drops the changelog table on the master
func (this *Applier) DropChangelogTable() error {
query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
log.Infof("Droppping changelog table %s.%s",
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err
}
log.Infof("Changelog table dropped")
return nil
}
// WriteChangelog writes a value to the changelog table.
// It returns the hint as given, for convenience
func (this *Applier) WriteChangelog(hint, value string) (string, error) {
query := fmt.Sprintf(`
insert /* gh-osc */ into %s.%s
(id, hint, value)
values
(NULL, ?, ?)
on duplicate key update
last_update=NOW(),
value=VALUES(value)
`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
_, err := sqlutils.Exec(this.db, query, hint, value)
return hint, err
}
// InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table.
// This is done asynchronously
func (this *Applier) InitiateHeartbeat() {
go func() {
numSuccessiveFailures := 0
query := fmt.Sprintf(`
insert /* gh-osc */ into %s.%s
(id, hint, value)
values
(1, 'heartbeat', ?)
on duplicate key update
last_update=NOW(),
value=VALUES(value)
`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
injectHeartbeat := func() error {
if _, err := sqlutils.ExecNoPrepare(this.db, query, time.Now().Format(time.RFC3339)); err != nil {
numSuccessiveFailures++
if numSuccessiveFailures > this.migrationContext.MaxRetries() {
return log.Errore(err)
}
} else {
numSuccessiveFailures = 0
}
return nil
}
injectHeartbeat()
heartbeatTick := time.Tick(time.Duration(heartbeatIntervalSeconds) * time.Second)
for range heartbeatTick {
// Generally speaking, we would issue a goroutine, but I'd actually rather
// have this blocked rather than spam the master in the event something
// goes wrong
if err := injectHeartbeat(); err != nil {
return
}
}
}()
}
// ReadMigrationMinValues
func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error {
log.Debugf("Reading migration range according to key: %s", uniqueKey.Name)
query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns)
if err != nil {
return err
}
rows, err := this.db.Query(query)
if err != nil {
return err
}
for rows.Next() {
this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(len(uniqueKey.Columns))
if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil {
return err
}
}
log.Infof("Migration min values: [%s]", this.migrationContext.MigrationRangeMinValues)
return err
}
// ReadMigrationMinValues
func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error {
log.Debugf("Reading migration range according to key: %s", uniqueKey.Name)
query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns)
if err != nil {
return err
}
rows, err := this.db.Query(query)
if err != nil {
return err
}
for rows.Next() {
this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(len(uniqueKey.Columns))
if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil {
return err
}
}
log.Infof("Migration max values: [%s]", this.migrationContext.MigrationRangeMaxValues)
return err
}
func (this *Applier) ReadMigrationRangeValues() error {
if err := this.ReadMigrationMinValues(this.migrationContext.UniqueKey); err != nil {
return err
}
if err := this.ReadMigrationMaxValues(this.migrationContext.UniqueKey); err != nil {
return err
}
return nil
}
// __unused_IterationIsComplete lets us know when the copy-iteration phase is complete, i.e.
// we've exhausted all rows
func (this *Applier) __unused_IterationIsComplete() (bool, error) {
if !this.migrationContext.HasMigrationRange() {
return false, nil
}
if this.migrationContext.MigrationIterationRangeMinValues == nil {
return false, nil
}
args := sqlutils.Args()
compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign)
if err != nil {
return false, err
}
args = append(args, explodedArgs...)
compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign)
if err != nil {
return false, err
}
args = append(args, explodedArgs...)
query := fmt.Sprintf(`
select /* gh-osc IterationIsComplete */ 1
from %s.%s
where (%s) and (%s)
limit 1
`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.OriginalTableName),
compareWithIterationRangeStart,
compareWithRangeEnd,
)
moreRowsFound := false
err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error {
moreRowsFound = true
return nil
}, args...)
if err != nil {
return false, err
}
return !moreRowsFound, nil
}
// CalculateNextIterationRangeEndValues reads the next-iteration-range-end unique key values,
// which will be used for copying the next chunk of rows. Ir returns "false" if there is
// no further chunk to work through, i.e. we're past the last chunk and are done with
// itrating the range (and this done with copying row chunks)
func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange bool, err error) {
this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationIterationRangeMaxValues
if this.migrationContext.MigrationIterationRangeMinValues == nil {
this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationRangeMinValues
}
query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery(
this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName,
this.migrationContext.UniqueKey.Columns,
this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(),
this.migrationContext.MigrationRangeMaxValues.AbstractValues(),
this.migrationContext.ChunkSize,
fmt.Sprintf("iteration:%d", this.migrationContext.Iteration),
)
if err != nil {
return hasFurtherRange, err
}
rows, err := this.db.Query(query, explodedArgs...)
if err != nil {
return hasFurtherRange, err
}
iterationRangeMaxValues := sql.NewColumnValues(len(this.migrationContext.UniqueKey.Columns))
for rows.Next() {
if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil {
return hasFurtherRange, err
}
hasFurtherRange = true
}
if !hasFurtherRange {
log.Debugf("Iteration complete: cannot find iteration end")
return hasFurtherRange, nil
}
this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues
log.Debugf(
"column values: [%s]..[%s]; iteration: %d; chunk-size: %d",
this.migrationContext.MigrationIterationRangeMinValues,
this.migrationContext.MigrationIterationRangeMaxValues,
this.migrationContext.Iteration,
this.migrationContext.ChunkSize,
)
return hasFurtherRange, nil
}
func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected int64, duration time.Duration, err error) {
startTime := time.Now()
chunkSize = atomic.LoadInt64(&this.migrationContext.ChunkSize)
query, explodedArgs, err := sql.BuildRangeInsertPreparedQuery(
this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName,
this.migrationContext.GetGhostTableName(),
this.migrationContext.UniqueKey.Columns,
this.migrationContext.UniqueKey.Name,
this.migrationContext.UniqueKey.Columns,
this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(),
this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(),
this.migrationContext.Iteration == 0,
this.migrationContext.IsTransactionalTable(),
)
if err != nil {
return chunkSize, rowsAffected, duration, err
}
sqlResult, err := sqlutils.Exec(this.db, query, explodedArgs...)
if err != nil {
return chunkSize, rowsAffected, duration, err
}
rowsAffected, _ = sqlResult.RowsAffected()
duration = time.Now().Sub(startTime)
this.WriteChangelog(
fmt.Sprintf("copy iteration %d", this.migrationContext.Iteration),
fmt.Sprintf("chunk: %d; affected: %d; duration: %d", chunkSize, rowsAffected, duration),
)
log.Debugf(
"Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d",
this.migrationContext.MigrationIterationRangeMinValues,
this.migrationContext.MigrationIterationRangeMaxValues,
this.migrationContext.Iteration,
chunkSize)
return chunkSize, rowsAffected, duration, nil
}
// LockTables
func (this *Applier) LockTables() error {
query := fmt.Sprintf(`lock /* gh-osc */ tables %s.%s write, %s.%s write, %s.%s write`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.OriginalTableName),
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetGhostTableName()),
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
log.Infof("Locking tables")
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err
}
log.Infof("Tables locked")
return nil
}
// UnlockTables
func (this *Applier) UnlockTables() error {
query := `unlock /* gh-osc */ tables`
log.Infof("Unlocking tables")
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err
}
log.Infof("Tables unlocked")
return nil
}

View File

@ -26,15 +26,15 @@ type Inspector struct {
migrationContext *base.MigrationContext
}
func NewInspector(connectionConfig *mysql.ConnectionConfig) *Inspector {
func NewInspector() *Inspector {
return &Inspector{
connectionConfig: connectionConfig,
connectionConfig: base.GetMigrationContext().InspectorConnectionConfig,
migrationContext: base.GetMigrationContext(),
}
}
func (this *Inspector) InitDBConnections() (err error) {
inspectorUri := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", this.connectionConfig.User, this.connectionConfig.Password, this.connectionConfig.Hostname, this.connectionConfig.Port, this.migrationContext.DatabaseName)
inspectorUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
if this.db, _, err = sqlutils.GetDB(inspectorUri); err != nil {
return err
}
@ -47,9 +47,16 @@ func (this *Inspector) InitDBConnections() (err error) {
if err := this.validateBinlogs(); err != nil {
return err
}
return nil
}
func (this *Inspector) ValidateOriginalTable() (err error) {
if err := this.validateTable(); err != nil {
return err
}
if err := this.validateTableForeignKeys(); err != nil {
return err
}
if this.migrationContext.CountTableRows {
if err := this.countTableRows(); err != nil {
return err
@ -59,32 +66,31 @@ func (this *Inspector) InitDBConnections() (err error) {
return err
}
}
return nil
}
func (this *Inspector) InspectTables() (err error) {
uniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName)
func (this *Inspector) InspectOriginalTable() (uniqueKeys [](*sql.UniqueKey), err error) {
uniqueKeys, err = this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName)
if err != nil {
return err
return uniqueKeys, err
}
if len(uniqueKeys) == 0 {
return fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out")
return uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out")
}
return nil
return uniqueKeys, err
}
// validateConnection issues a simple can-connect to MySQL
func (this *Inspector) validateConnection() error {
query := `select @@port`
query := `select @@global.port`
var port int
if err := this.db.QueryRow(query).Scan(&port); err != nil {
return err
}
if port != this.connectionConfig.Port {
if port != this.connectionConfig.Key.Port {
return fmt.Errorf("Unexpected database port reported: %+v", port)
}
log.Infof("connection validated on port %+v", port)
log.Infof("connection validated on %+v", this.connectionConfig.Key)
return nil
}
@ -116,7 +122,7 @@ func (this *Inspector) validateGrants() error {
return nil
})
if err != nil {
return log.Errore(err)
return err
}
if foundAll {
@ -130,7 +136,7 @@ func (this *Inspector) validateGrants() error {
return log.Errorf("User has insufficient privileges for migration.")
}
// validateConnection issues a simple can-connect to MySQL
// validateBinlogs checks that binary log configuration is good to go
func (this *Inspector) validateBinlogs() error {
query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format`
var hasBinaryLogs, logSlaveUpdates bool
@ -138,10 +144,10 @@ func (this *Inspector) validateBinlogs() error {
return err
}
if !hasBinaryLogs {
return fmt.Errorf("%s:%d must have binary logs enabled", this.connectionConfig.Hostname, this.connectionConfig.Port)
return fmt.Errorf("%s:%d must have binary logs enabled", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port)
}
if !logSlaveUpdates {
return fmt.Errorf("%s:%d must have log_slave_updates enabled", this.connectionConfig.Hostname, this.connectionConfig.Port)
return fmt.Errorf("%s:%d must have log_slave_updates enabled", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port)
}
if this.migrationContext.RequiresBinlogFormatChange() {
query := fmt.Sprintf(`show /* gh-osc */ slave hosts`)
@ -151,12 +157,12 @@ func (this *Inspector) validateBinlogs() error {
return nil
})
if err != nil {
return log.Errore(err)
return err
}
if countReplicas > 0 {
return fmt.Errorf("%s:%d has %s binlog_format, but I'm too scared to change it to ROW because it has replicas. Bailing out", this.connectionConfig.Hostname, this.connectionConfig.Port, this.migrationContext.OriginalBinlogFormat)
return fmt.Errorf("%s:%d has %s binlog_format, but I'm too scared to change it to ROW because it has replicas. Bailing out", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogFormat)
}
log.Infof("%s:%d has %s binlog_format. I will change it to ROW for the duration of this migration.", this.connectionConfig.Hostname, this.connectionConfig.Port, this.migrationContext.OriginalBinlogFormat)
log.Infof("%s:%d has %s binlog_format. I will change it to ROW for the duration of this migration.", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogFormat)
}
query = `select @@global.binlog_row_image`
if err := this.db.QueryRow(query).Scan(&this.migrationContext.OriginalBinlogRowImage); err != nil {
@ -164,7 +170,7 @@ func (this *Inspector) validateBinlogs() error {
this.migrationContext.OriginalBinlogRowImage = ""
}
log.Infof("binary logs validated on %s:%d", this.connectionConfig.Hostname, this.connectionConfig.Port)
log.Infof("binary logs validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port)
return nil
}
@ -185,7 +191,7 @@ func (this *Inspector) validateTable() error {
return nil
})
if err != nil {
return log.Errore(err)
return err
}
if !tableFound {
return log.Errorf("Cannot find table %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName))
@ -195,6 +201,37 @@ func (this *Inspector) validateTable() error {
return nil
}
func (this *Inspector) validateTableForeignKeys() error {
query := `
SELECT COUNT(*) AS num_foreign_keys
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
WHERE
REFERENCED_TABLE_NAME IS NOT NULL
AND ((TABLE_SCHEMA=? AND TABLE_NAME=?)
OR (REFERENCED_TABLE_SCHEMA=? AND REFERENCED_TABLE_NAME=?)
)
`
numForeignKeys := 0
err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error {
numForeignKeys = rowMap.GetInt("num_foreign_keys")
return nil
},
this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName,
this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName,
)
if err != nil {
return err
}
if numForeignKeys > 0 {
return log.Errorf("Found %d foreign keys on %s.%s. Foreign keys are not supported. Bailing out", numForeignKeys, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName))
}
log.Debugf("Validated no foreign keys exist on table")
return nil
}
func (this *Inspector) estimateTableRowsViaExplain() error {
query := fmt.Sprintf(`explain select /* gh-osc */ * from %s.%s where 1=1`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName))
@ -207,7 +244,7 @@ func (this *Inspector) estimateTableRowsViaExplain() error {
return nil
})
if err != nil {
return log.Errore(err)
return err
}
if !outputFound {
return log.Errorf("Cannot run EXPLAIN on %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName))
@ -227,6 +264,29 @@ func (this *Inspector) countTableRows() error {
return nil
}
func (this *Inspector) getTableColumns(databaseName, tableName string) (columns sql.ColumnList, err error) {
query := fmt.Sprintf(`
show columns from %s.%s
`,
sql.EscapeName(databaseName),
sql.EscapeName(tableName),
)
err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error {
columns = append(columns, rowMap.GetString("Field"))
return nil
})
if err != nil {
return columns, err
}
if len(columns) == 0 {
return columns, log.Errorf("Found 0 columns on %s.%s. Bailing out",
sql.EscapeName(databaseName),
sql.EscapeName(tableName),
)
}
return columns, nil
}
// getCandidateUniqueKeys investigates a table and returns the list of unique keys
// candidate for chunking
func (this *Inspector) getCandidateUniqueKeys(tableName string) (uniqueKeys [](*sql.UniqueKey), err error) {
@ -308,7 +368,7 @@ func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err
if err != nil {
return uniqueKeys, err
}
ghostUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.GhostTableName)
ghostUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.GetGhostTableName())
if err != nil {
return uniqueKeys, err
}
@ -323,3 +383,44 @@ func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err
}
return uniqueKeys, nil
}
func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) {
visitedKeys := mysql.NewInstanceKeyMap()
return getMasterConnectionConfigSafe(this.connectionConfig, this.migrationContext.DatabaseName, visitedKeys)
}
func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, databaseName string, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) {
log.Debugf("Looking for master on %+v", connectionConfig.Key)
currentUri := connectionConfig.GetDBUri(databaseName)
db, _, err := sqlutils.GetDB(currentUri)
if err != nil {
return nil, err
}
hasMaster := false
masterConfig = connectionConfig.Duplicate()
err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error {
masterKey := mysql.InstanceKey{
Hostname: rowMap.GetString("Master_Host"),
Port: rowMap.GetInt("Master_Port"),
}
if masterKey.IsValid() {
masterConfig.Key = masterKey
hasMaster = true
}
return nil
})
if err != nil {
return nil, err
}
if hasMaster {
log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key)
if visitedKeys.HasKey(masterConfig.Key) {
return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key)
}
visitedKeys.AddKey(masterConfig.Key)
return getMasterConnectionConfigSafe(masterConfig, databaseName, visitedKeys)
}
return masterConfig, nil
}

View File

@ -6,28 +6,350 @@
package logic
import (
"github.com/github/gh-osc/go/mysql"
"fmt"
"os"
"sync/atomic"
"time"
"github.com/github/gh-osc/go/base"
"github.com/github/gh-osc/go/binlog"
"github.com/outbrain/golib/log"
)
type ChangelogState string
const (
TablesInPlace ChangelogState = "TablesInPlace"
AllEventsUpToLockProcessed = "AllEventsUpToLockProcessed"
)
type tableWriteFunc func() error
const (
applyEventsQueueBuffer = 100
)
// Migrator is the main schema migration flow manager.
type Migrator struct {
connectionConfig *mysql.ConnectionConfig
inspector *Inspector
applier *Applier
eventsStreamer *EventsStreamer
migrationContext *base.MigrationContext
tablesInPlace chan bool
rowCopyComplete chan bool
allEventsUpToLockProcessed chan bool
// copyRowsQueue should not be buffered; if buffered some non-damaging but
// excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete
copyRowsQueue chan tableWriteFunc
applyEventsQueue chan tableWriteFunc
}
func NewMigrator(connectionConfig *mysql.ConnectionConfig) *Migrator {
return &Migrator{
connectionConfig: connectionConfig,
inspector: NewInspector(connectionConfig),
func NewMigrator() *Migrator {
migrator := &Migrator{
migrationContext: base.GetMigrationContext(),
tablesInPlace: make(chan bool),
rowCopyComplete: make(chan bool),
allEventsUpToLockProcessed: make(chan bool),
copyRowsQueue: make(chan tableWriteFunc),
applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer),
}
migrator.migrationContext.IsThrottled = func() bool {
return migrator.shouldThrottle()
}
return migrator
}
func (this *Migrator) Migrate() error {
func (this *Migrator) shouldThrottle() bool {
lag := atomic.LoadInt64(&this.migrationContext.CurrentLag)
shouldThrottle := false
if time.Duration(lag) > time.Duration(this.migrationContext.MaxLagMillisecondsThrottleThreshold)*time.Millisecond {
shouldThrottle = true
} else if this.migrationContext.ThrottleFlagFile != "" {
if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil {
//Throttle file defined and exists!
shouldThrottle = true
}
}
return shouldThrottle
}
func (this *Migrator) canStopStreaming() bool {
return false
}
func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) {
// Hey, I created the changlog table, I know the type of columns it has!
if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" {
return
}
changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3))
switch changelogState {
case TablesInPlace:
{
this.tablesInPlace <- true
}
case AllEventsUpToLockProcessed:
{
this.allEventsUpToLockProcessed <- true
}
default:
{
return fmt.Errorf("Unknown changelog state: %+v", changelogState)
}
}
log.Debugf("---- - - - - - state %+v", changelogState)
return nil
}
func (this *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) {
if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "heartbeat" {
return nil
}
value := dmlEvent.NewColumnValues.StringColumn(3)
heartbeatTime, err := time.Parse(time.RFC3339, value)
if err != nil {
return log.Errore(err)
}
lag := time.Now().Sub(heartbeatTime)
atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag))
return nil
}
func (this *Migrator) Migrate() (err error) {
this.migrationContext.StartTime = time.Now()
this.inspector = NewInspector()
if err := this.inspector.InitDBConnections(); err != nil {
return err
}
if err := this.inspector.InspectTables(); err != nil {
if err := this.inspector.ValidateOriginalTable(); err != nil {
return err
}
uniqueKeys, err := this.inspector.InspectOriginalTable()
if err != nil {
return err
}
// So far so good, table is accessible and valid.
if this.migrationContext.MasterConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil {
return err
}
if this.migrationContext.IsRunningOnMaster() && !this.migrationContext.AllowedRunningOnMaster {
return fmt.Errorf("It seems like this migration attempt to run directly on master. Preferably it would be executed on a replica (and this reduces load from the master). To proceed please provide --allow-on-master")
}
log.Infof("Master found to be %+v", this.migrationContext.MasterConnectionConfig.Key)
if err := this.initiateStreaming(); err != nil {
return err
}
if err := this.initiateApplier(); err != nil {
return err
}
log.Debugf("Waiting for tables to be in place")
<-this.tablesInPlace
log.Debugf("Tables are in place")
// Yay! We now know the Ghost and Changelog tables are good to examine!
// When running on replica, this means the replica has those tables. When running
// on master this is always true, of course, and yet it also implies this knowledge
// is in the binlogs.
this.migrationContext.UniqueKey = uniqueKeys[0] // TODO. Need to wait on replica till the ghost table exists and get shared keys
if err := this.applier.ReadMigrationRangeValues(); err != nil {
return err
}
go this.initiateStatus()
go this.executeWriteFuncs()
go this.iterateChunks()
log.Debugf("Operating until row copy is complete")
<-this.rowCopyComplete
log.Debugf("Row copy complete")
this.printStatus()
throttleMigration(
this.migrationContext,
func() {
log.Debugf("throttling before LOCK TABLES")
},
nil,
func() {
log.Debugf("done throttling")
},
)
// TODO retries!!
this.applier.LockTables()
this.applier.WriteChangelog("state", string(AllEventsUpToLockProcessed))
log.Debugf("Waiting for events up to lock")
<-this.allEventsUpToLockProcessed
log.Debugf("Done waiting for events up to lock")
// TODO retries!!
this.applier.UnlockTables()
return nil
}
func (this *Migrator) initiateStatus() error {
this.printStatus()
statusTick := time.Tick(1 * time.Second)
for range statusTick {
go this.printStatus()
}
return nil
}
func (this *Migrator) printStatus() {
elapsedTime := this.migrationContext.ElapsedTime()
elapsedSeconds := int64(elapsedTime.Seconds())
totalRowsCopied := this.migrationContext.GetTotalRowsCopied()
rowsEstimate := this.migrationContext.RowsEstimate
progressPct := 100.0 * float64(totalRowsCopied) / float64(rowsEstimate)
shouldPrintStatus := false
if elapsedSeconds <= 60 {
shouldPrintStatus = true
} else if progressPct >= 99.0 {
shouldPrintStatus = true
} else if progressPct >= 95.0 {
shouldPrintStatus = (elapsedSeconds%5 == 0)
} else if elapsedSeconds <= 120 {
shouldPrintStatus = (elapsedSeconds%5 == 0)
} else {
shouldPrintStatus = (elapsedSeconds%30 == 0)
}
if !shouldPrintStatus {
return
}
status := fmt.Sprintf("Copy: %d/%d %.1f%% Backlog: %d/%d Elapsed: %+v(copy), %+v(total) ETA: N/A",
totalRowsCopied, rowsEstimate, progressPct,
len(this.applyEventsQueue), cap(this.applyEventsQueue),
this.migrationContext.ElapsedRowCopyTime(), elapsedTime)
fmt.Println(status)
}
func (this *Migrator) initiateStreaming() error {
this.eventsStreamer = NewEventsStreamer()
if err := this.eventsStreamer.InitDBConnections(); err != nil {
return err
}
this.eventsStreamer.AddListener(
false,
this.migrationContext.DatabaseName,
this.migrationContext.GetChangelogTableName(),
func(dmlEvent *binlog.BinlogDMLEvent) error {
return this.onChangelogStateEvent(dmlEvent)
},
)
this.eventsStreamer.AddListener(
false,
this.migrationContext.DatabaseName,
this.migrationContext.GetChangelogTableName(),
func(dmlEvent *binlog.BinlogDMLEvent) error {
return this.onChangelogHeartbeatEvent(dmlEvent)
},
)
go func() {
log.Debugf("Beginning streaming")
this.eventsStreamer.StreamEvents(func() bool { return this.canStopStreaming() })
}()
return nil
}
func (this *Migrator) initiateApplier() error {
this.applier = NewApplier()
if err := this.applier.InitDBConnections(); err != nil {
return err
}
if err := this.applier.CreateGhostTable(); err != nil {
log.Errorf("Unable to create ghost table, see further error details. Perhaps a previous migration failed without dropping the table? Bailing out")
return err
}
if err := this.applier.AlterGhost(); err != nil {
log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out")
return err
}
if err := this.applier.CreateChangelogTable(); err != nil {
log.Errorf("Unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out")
return err
}
this.applier.WriteChangelog("state", string(TablesInPlace))
this.applier.InitiateHeartbeat()
return nil
}
func (this *Migrator) iterateChunks() error {
this.migrationContext.RowCopyStartTime = time.Now()
terminateRowIteration := func(err error) error {
this.rowCopyComplete <- true
return log.Errore(err)
}
for {
copyRowsFunc := func() error {
hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues()
if err != nil {
return terminateRowIteration(err)
}
if !hasFurtherRange {
return terminateRowIteration(nil)
}
_, rowsAffected, _, err := this.applier.ApplyIterationInsertQuery()
if err != nil {
return terminateRowIteration(err)
}
atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected)
this.migrationContext.Iteration++
return nil
}
this.copyRowsQueue <- copyRowsFunc
}
return nil
}
func (this *Migrator) executeWriteFuncs() error {
for {
throttleMigration(
this.migrationContext,
func() {
log.Debugf("throttling writes")
},
nil,
func() {
log.Debugf("done throttling writes")
},
)
// We give higher priority to event processing, then secondary priority to
// rowcopy
select {
case applyEventFunc := <-this.applyEventsQueue:
{
retryOperation(applyEventFunc, this.migrationContext.MaxRetries())
}
default:
{
select {
case copyRowsFunc := <-this.copyRowsQueue:
{
retryOperation(copyRowsFunc, this.migrationContext.MaxRetries())
}
default:
{
// Hmmmmm... nothing in the queue; no events, but also no row copy.
// This is possible upon load. Let's just sleep it over.
log.Debugf("Getting nothing in the write queue. Sleeping...")
time.Sleep(time.Second)
}
}
}
}
}
return nil
}

162
go/logic/streamer.go Normal file
View File

@ -0,0 +1,162 @@
/*
Copyright 2016 GitHub Inc.
See https://github.com/github/gh-osc/blob/master/LICENSE
*/
package logic
import (
gosql "database/sql"
"fmt"
"strings"
"github.com/github/gh-osc/go/base"
"github.com/github/gh-osc/go/binlog"
"github.com/github/gh-osc/go/mysql"
"github.com/outbrain/golib/log"
"github.com/outbrain/golib/sqlutils"
)
type BinlogEventListener struct {
async bool
databaseName string
tableName string
onDmlEvent func(event *binlog.BinlogDMLEvent) error
}
const (
EventsChannelBufferSize = 1
)
// EventsStreamer reads data from binary logs and streams it on. It acts as a publisher,
// and interested parties may subscribe for per-table events.
type EventsStreamer struct {
connectionConfig *mysql.ConnectionConfig
db *gosql.DB
migrationContext *base.MigrationContext
nextBinlogCoordinates *mysql.BinlogCoordinates
listeners [](*BinlogEventListener)
eventsChannel chan *binlog.BinlogEntry
binlogReader binlog.BinlogReader
}
func NewEventsStreamer() *EventsStreamer {
return &EventsStreamer{
connectionConfig: base.GetMigrationContext().InspectorConnectionConfig,
migrationContext: base.GetMigrationContext(),
listeners: [](*BinlogEventListener){},
eventsChannel: make(chan *binlog.BinlogEntry, EventsChannelBufferSize),
}
}
func (this *EventsStreamer) AddListener(
async bool, databaseName string, tableName string, onDmlEvent func(event *binlog.BinlogDMLEvent) error) (err error) {
if databaseName == "" {
return fmt.Errorf("Empty database name in AddListener")
}
if tableName == "" {
return fmt.Errorf("Empty table name in AddListener")
}
listener := &BinlogEventListener{
async: async,
databaseName: databaseName,
tableName: tableName,
onDmlEvent: onDmlEvent,
}
this.listeners = append(this.listeners, listener)
return nil
}
func (this *EventsStreamer) notifyListeners(binlogEvent *binlog.BinlogDMLEvent) {
for _, listener := range this.listeners {
if strings.ToLower(listener.databaseName) != strings.ToLower(binlogEvent.DatabaseName) {
continue
}
if strings.ToLower(listener.tableName) != strings.ToLower(binlogEvent.TableName) {
continue
}
onDmlEvent := listener.onDmlEvent
if listener.async {
go func() {
onDmlEvent(binlogEvent)
}()
} else {
onDmlEvent(binlogEvent)
}
}
}
func (this *EventsStreamer) InitDBConnections() (err error) {
EventsStreamerUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
if this.db, _, err = sqlutils.GetDB(EventsStreamerUri); err != nil {
return err
}
if err := this.validateConnection(); err != nil {
return err
}
if err := this.readCurrentBinlogCoordinates(); err != nil {
return err
}
goMySQLReader, err := binlog.NewGoMySQLReader(this.migrationContext.InspectorConnectionConfig)
if err != nil {
return err
}
if err := goMySQLReader.ConnectBinlogStreamer(*this.nextBinlogCoordinates); err != nil {
return err
}
this.binlogReader = goMySQLReader
return nil
}
// validateConnection issues a simple can-connect to MySQL
func (this *EventsStreamer) validateConnection() error {
query := `select @@global.port`
var port int
if err := this.db.QueryRow(query).Scan(&port); err != nil {
return err
}
if port != this.connectionConfig.Key.Port {
return fmt.Errorf("Unexpected database port reported: %+v", port)
}
log.Infof("connection validated on %+v", this.connectionConfig.Key)
return nil
}
// validateGrants verifies the user by which we're executing has necessary grants
// to do its thang.
func (this *EventsStreamer) readCurrentBinlogCoordinates() error {
query := `show /* gh-osc readCurrentBinlogCoordinates */ master status`
foundMasterStatus := false
err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error {
this.nextBinlogCoordinates = &mysql.BinlogCoordinates{
LogFile: m.GetString("File"),
LogPos: m.GetInt64("Position"),
}
foundMasterStatus = true
return nil
})
if err != nil {
return err
}
if !foundMasterStatus {
return fmt.Errorf("Got no results from SHOW MASTER STATUS. Bailing out")
}
log.Debugf("Streamer binlog coordinates: %+v", *this.nextBinlogCoordinates)
return nil
}
// StreamEvents will begin streaming events. It will be blocking, so should be
// executed by a goroutine
func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error {
go func() {
for binlogEntry := range this.eventsChannel {
if binlogEntry.DmlEvent != nil {
this.notifyListeners(binlogEntry.DmlEvent)
}
}
}()
return this.binlogReader.StreamEvents(canStopStreaming, this.eventsChannel)
}

162
go/mysql/binlog.go Normal file
View File

@ -0,0 +1,162 @@
/*
Copyright 2015 Shlomi Noach, courtesy Booking.com
See https://github.com/github/gh-osc/blob/master/LICENSE
*/
package mysql
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
)
var detachPattern *regexp.Regexp
func init() {
detachPattern, _ = regexp.Compile(`//([^/:]+):([\d]+)`) // e.g. `//binlog.01234:567890`
}
type BinlogType int
const (
BinaryLog BinlogType = iota
RelayLog
)
// BinlogCoordinates described binary log coordinates in the form of log file & log position.
type BinlogCoordinates struct {
LogFile string
LogPos int64
Type BinlogType
}
// ParseInstanceKey will parse an InstanceKey from a string representation such as 127.0.0.1:3306
func ParseBinlogCoordinates(logFileLogPos string) (*BinlogCoordinates, error) {
tokens := strings.SplitN(logFileLogPos, ":", 2)
if len(tokens) != 2 {
return nil, fmt.Errorf("ParseBinlogCoordinates: Cannot parse BinlogCoordinates from %s. Expected format is file:pos", logFileLogPos)
}
if logPos, err := strconv.ParseInt(tokens[1], 10, 0); err != nil {
return nil, fmt.Errorf("ParseBinlogCoordinates: invalid pos: %s", tokens[1])
} else {
return &BinlogCoordinates{LogFile: tokens[0], LogPos: logPos}, nil
}
}
// DisplayString returns a user-friendly string representation of these coordinates
func (this *BinlogCoordinates) DisplayString() string {
return fmt.Sprintf("%s:%d", this.LogFile, this.LogPos)
}
// String returns a user-friendly string representation of these coordinates
func (this BinlogCoordinates) String() string {
return this.DisplayString()
}
// Equals tests equality of this corrdinate and another one.
func (this *BinlogCoordinates) Equals(other *BinlogCoordinates) bool {
if other == nil {
return false
}
return this.LogFile == other.LogFile && this.LogPos == other.LogPos && this.Type == other.Type
}
// IsEmpty returns true if the log file is empty, unnamed
func (this *BinlogCoordinates) IsEmpty() bool {
return this.LogFile == ""
}
// SmallerThan returns true if this coordinate is strictly smaller than the other.
func (this *BinlogCoordinates) SmallerThan(other *BinlogCoordinates) bool {
if this.LogFile < other.LogFile {
return true
}
if this.LogFile == other.LogFile && this.LogPos < other.LogPos {
return true
}
return false
}
// SmallerThanOrEquals returns true if this coordinate is the same or equal to the other one.
// We do NOT compare the type so we can not use this.Equals()
func (this *BinlogCoordinates) SmallerThanOrEquals(other *BinlogCoordinates) bool {
if this.SmallerThan(other) {
return true
}
return this.LogFile == other.LogFile && this.LogPos == other.LogPos // No Type comparison
}
// FileSmallerThan returns true if this coordinate's file is strictly smaller than the other's.
func (this *BinlogCoordinates) FileSmallerThan(other *BinlogCoordinates) bool {
return this.LogFile < other.LogFile
}
// FileNumberDistance returns the numeric distance between this corrdinate's file number and the other's.
// Effectively it means "how many roatets/FLUSHes would make these coordinates's file reach the other's"
func (this *BinlogCoordinates) FileNumberDistance(other *BinlogCoordinates) int {
thisNumber, _ := this.FileNumber()
otherNumber, _ := other.FileNumber()
return otherNumber - thisNumber
}
// FileNumber returns the numeric value of the file, and the length in characters representing the number in the filename.
// Example: FileNumber() of mysqld.log.000789 is (789, 6)
func (this *BinlogCoordinates) FileNumber() (int, int) {
tokens := strings.Split(this.LogFile, ".")
numPart := tokens[len(tokens)-1]
numLen := len(numPart)
fileNum, err := strconv.Atoi(numPart)
if err != nil {
return 0, 0
}
return fileNum, numLen
}
// PreviousFileCoordinatesBy guesses the filename of the previous binlog/relaylog, by given offset (number of files back)
func (this *BinlogCoordinates) PreviousFileCoordinatesBy(offset int) (BinlogCoordinates, error) {
result := BinlogCoordinates{LogPos: 0, Type: this.Type}
fileNum, numLen := this.FileNumber()
if fileNum == 0 {
return result, errors.New("Log file number is zero, cannot detect previous file")
}
newNumStr := fmt.Sprintf("%d", (fileNum - offset))
newNumStr = strings.Repeat("0", numLen-len(newNumStr)) + newNumStr
tokens := strings.Split(this.LogFile, ".")
tokens[len(tokens)-1] = newNumStr
result.LogFile = strings.Join(tokens, ".")
return result, nil
}
// PreviousFileCoordinates guesses the filename of the previous binlog/relaylog
func (this *BinlogCoordinates) PreviousFileCoordinates() (BinlogCoordinates, error) {
return this.PreviousFileCoordinatesBy(1)
}
// PreviousFileCoordinates guesses the filename of the previous binlog/relaylog
func (this *BinlogCoordinates) NextFileCoordinates() (BinlogCoordinates, error) {
result := BinlogCoordinates{LogPos: 0, Type: this.Type}
fileNum, numLen := this.FileNumber()
newNumStr := fmt.Sprintf("%d", (fileNum + 1))
newNumStr = strings.Repeat("0", numLen-len(newNumStr)) + newNumStr
tokens := strings.Split(this.LogFile, ".")
tokens[len(tokens)-1] = newNumStr
result.LogFile = strings.Join(tokens, ".")
return result, nil
}
// FileSmallerThan returns true if this coordinate's file is strictly smaller than the other's.
func (this *BinlogCoordinates) DetachedCoordinates() (isDetached bool, detachedLogFile string, detachedLogPos string) {
detachedCoordinatesSubmatch := detachPattern.FindStringSubmatch(this.LogFile)
if len(detachedCoordinatesSubmatch) == 0 {
return false, "", ""
}
return true, detachedCoordinatesSubmatch[1], detachedCoordinatesSubmatch[2]
}

View File

@ -5,10 +5,44 @@
package mysql
import (
"fmt"
)
// ConnectionConfig is the minimal configuration required to connect to a MySQL server
type ConnectionConfig struct {
Hostname string
Port int
Key InstanceKey
User string
Password string
}
func NewConnectionConfig() *ConnectionConfig {
config := &ConnectionConfig{
Key: InstanceKey{},
}
return config
}
func (this *ConnectionConfig) Duplicate() *ConnectionConfig {
config := &ConnectionConfig{
Key: InstanceKey{
Hostname: this.Key.Hostname,
Port: this.Key.Port,
},
User: this.User,
Password: this.Password,
}
return config
}
func (this *ConnectionConfig) String() string {
return fmt.Sprintf("%s, user=%s", this.Key.DisplayString(), this.User)
}
func (this *ConnectionConfig) Equals(other *ConnectionConfig) bool {
return this.Key.Equals(&other.Key)
}
func (this *ConnectionConfig) GetDBUri(databaseName string) string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", this.User, this.Password, this.Key.Hostname, this.Key.Port, databaseName)
}

115
go/mysql/instance_key.go Normal file
View File

@ -0,0 +1,115 @@
/*
Copyright 2015 Shlomi Noach, courtesy Booking.com
See https://github.com/github/gh-osc/blob/master/LICENSE
*/
package mysql
import (
"fmt"
"strconv"
"strings"
)
const (
DefaultInstancePort = 3306
)
// InstanceKey is an instance indicator, identifued by hostname and port
type InstanceKey struct {
Hostname string
Port int
}
const detachHint = "//"
// ParseInstanceKey will parse an InstanceKey from a string representation such as 127.0.0.1:3306
func NewRawInstanceKey(hostPort string) (*InstanceKey, error) {
tokens := strings.SplitN(hostPort, ":", 2)
if len(tokens) != 2 {
return nil, fmt.Errorf("Cannot parse InstanceKey from %s. Expected format is host:port", hostPort)
}
instanceKey := &InstanceKey{Hostname: tokens[0]}
var err error
if instanceKey.Port, err = strconv.Atoi(tokens[1]); err != nil {
return instanceKey, fmt.Errorf("Invalid port: %s", tokens[1])
}
return instanceKey, nil
}
// ParseRawInstanceKeyLoose will parse an InstanceKey from a string representation such as 127.0.0.1:3306.
// The port part is optional; there will be no name resolve
func ParseRawInstanceKeyLoose(hostPort string) (*InstanceKey, error) {
if !strings.Contains(hostPort, ":") {
return &InstanceKey{Hostname: hostPort, Port: DefaultInstancePort}, nil
}
return NewRawInstanceKey(hostPort)
}
// Equals tests equality between this key and another key
func (this *InstanceKey) Equals(other *InstanceKey) bool {
if other == nil {
return false
}
return this.Hostname == other.Hostname && this.Port == other.Port
}
// SmallerThan returns true if this key is dictionary-smaller than another.
// This is used for consistent sorting/ordering; there's nothing magical about it.
func (this *InstanceKey) SmallerThan(other *InstanceKey) bool {
if this.Hostname < other.Hostname {
return true
}
if this.Hostname == other.Hostname && this.Port < other.Port {
return true
}
return false
}
// IsDetached returns 'true' when this hostname is logically "detached"
func (this *InstanceKey) IsDetached() bool {
return strings.HasPrefix(this.Hostname, detachHint)
}
// IsValid uses simple heuristics to see whether this key represents an actual instance
func (this *InstanceKey) IsValid() bool {
if this.Hostname == "_" {
return false
}
if this.IsDetached() {
return false
}
return len(this.Hostname) > 0 && this.Port > 0
}
// DetachedKey returns an instance key whose hostname is detahced: invalid, but recoverable
func (this *InstanceKey) DetachedKey() *InstanceKey {
if this.IsDetached() {
return this
}
return &InstanceKey{Hostname: fmt.Sprintf("%s%s", detachHint, this.Hostname), Port: this.Port}
}
// ReattachedKey returns an instance key whose hostname is detahced: invalid, but recoverable
func (this *InstanceKey) ReattachedKey() *InstanceKey {
if !this.IsDetached() {
return this
}
return &InstanceKey{Hostname: this.Hostname[len(detachHint):], Port: this.Port}
}
// StringCode returns an official string representation of this key
func (this *InstanceKey) StringCode() string {
return fmt.Sprintf("%s:%d", this.Hostname, this.Port)
}
// DisplayString returns a user-friendly string representation of this key
func (this *InstanceKey) DisplayString() string {
return this.StringCode()
}
// String returns a user-friendly string representation of this key
func (this InstanceKey) String() string {
return this.StringCode()
}

View File

@ -0,0 +1,95 @@
/*
Copyright 2015 Shlomi Noach, courtesy Booking.com
See https://github.com/github/gh-osc/blob/master/LICENSE
*/
package mysql
import (
"encoding/json"
"strings"
)
// InstanceKeyMap is a convenience struct for listing InstanceKey-s
type InstanceKeyMap map[InstanceKey]bool
func NewInstanceKeyMap() *InstanceKeyMap {
return &InstanceKeyMap{}
}
// AddKey adds a single key to this map
func (this *InstanceKeyMap) AddKey(key InstanceKey) {
(*this)[key] = true
}
// AddKeys adds all given keys to this map
func (this *InstanceKeyMap) AddKeys(keys []InstanceKey) {
for _, key := range keys {
this.AddKey(key)
}
}
// HasKey checks if given key is within the map
func (this *InstanceKeyMap) HasKey(key InstanceKey) bool {
_, ok := (*this)[key]
return ok
}
// GetInstanceKeys returns keys in this map in the form of an array
func (this *InstanceKeyMap) GetInstanceKeys() []InstanceKey {
res := []InstanceKey{}
for key := range *this {
res = append(res, key)
}
return res
}
// MarshalJSON will marshal this map as JSON
func (this *InstanceKeyMap) MarshalJSON() ([]byte, error) {
return json.Marshal(this.GetInstanceKeys())
}
// ToJSON will marshal this map as JSON
func (this *InstanceKeyMap) ToJSON() (string, error) {
bytes, err := this.MarshalJSON()
return string(bytes), err
}
// ToJSONString will marshal this map as JSON
func (this *InstanceKeyMap) ToJSONString() string {
s, _ := this.ToJSON()
return s
}
// ToCommaDelimitedList will export this map in comma delimited format
func (this *InstanceKeyMap) ToCommaDelimitedList() string {
keyDisplays := []string{}
for key := range *this {
keyDisplays = append(keyDisplays, key.DisplayString())
}
return strings.Join(keyDisplays, ",")
}
// ReadJson unmarshalls a json into this map
func (this *InstanceKeyMap) ReadJson(jsonString string) error {
var keys []InstanceKey
err := json.Unmarshal([]byte(jsonString), &keys)
if err != nil {
return err
}
this.AddKeys(keys)
return err
}
// ReadJson unmarshalls a json into this map
func (this *InstanceKeyMap) ReadCommaDelimitedList(list string) error {
tokens := strings.Split(list, ",")
for _, token := range tokens {
key, err := ParseRawInstanceKeyLoose(token)
if err != nil {
return err
}
this.AddKey(*key)
}
return nil
}

View File

@ -64,12 +64,15 @@ func BuildEqualsComparison(columns []string, values []string) (result string, er
return result, nil
}
func BuildRangeComparison(columns []string, values []string, comparisonSign ValueComparisonSign) (result string, err error) {
func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
if len(columns) == 0 {
return "", fmt.Errorf("Got 0 columns in GetRangeComparison")
return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison")
}
if len(columns) != len(values) {
return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
return "", explodedArgs, fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
}
if len(columns) != len(args) {
return "", explodedArgs, fmt.Errorf("Got %d columns but %d args in GetEqualsComparison", len(columns), len(args))
}
includeEquals := false
if comparisonSign == LessThanOrEqualsComparisonSign {
@ -87,43 +90,47 @@ func BuildRangeComparison(columns []string, values []string, comparisonSign Valu
value := values[i]
rangeComparison, err := BuildValueComparison(column, value, comparisonSign)
if err != nil {
return "", err
return "", explodedArgs, err
}
if len(columns[0:i]) > 0 {
equalitiesComparison, err := BuildEqualsComparison(columns[0:i], values[0:i])
if err != nil {
return "", err
return "", explodedArgs, err
}
comparison := fmt.Sprintf("(%s AND %s)", equalitiesComparison, rangeComparison)
comparisons = append(comparisons, comparison)
explodedArgs = append(explodedArgs, args[0:i]...)
explodedArgs = append(explodedArgs, args[i])
} else {
comparisons = append(comparisons, rangeComparison)
explodedArgs = append(explodedArgs, args[i])
}
}
if includeEquals {
comparison, err := BuildEqualsComparison(columns, values)
if err != nil {
return "", nil
return "", explodedArgs, nil
}
comparisons = append(comparisons, comparison)
explodedArgs = append(explodedArgs, args...)
}
result = strings.Join(comparisons, " or ")
result = fmt.Sprintf("(%s)", result)
return result, nil
return result, explodedArgs, nil
}
func BuildRangePreparedComparison(columns []string, comparisonSign ValueComparisonSign) (result string, err error) {
func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
values := make([]string, len(columns), len(columns))
for i := range columns {
values[i] = "?"
}
return BuildRangeComparison(columns, values, comparisonSign)
return BuildRangeComparison(columns, values, args, comparisonSign)
}
func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string) (string, error) {
func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) {
if len(sharedColumns) == 0 {
return "", fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery")
return "", explodedArgs, fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery")
}
databaseName = EscapeName(databaseName)
originalTableName = EscapeName(originalTableName)
@ -134,50 +141,63 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin
uniqueKey = EscapeName(uniqueKey)
sharedColumnsListing := strings.Join(sharedColumns, ", ")
rangeStartComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, GreaterThanOrEqualsComparisonSign)
if err != nil {
return "", err
var minRangeComparisonSign ValueComparisonSign = GreaterThanComparisonSign
if includeRangeStartValues {
minRangeComparisonSign = GreaterThanOrEqualsComparisonSign
}
rangeEndComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, LessThanOrEqualsComparisonSign)
rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, rangeStartArgs, minRangeComparisonSign)
if err != nil {
return "", err
return "", explodedArgs, err
}
query := fmt.Sprintf(`
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, rangeEndArgs, LessThanOrEqualsComparisonSign)
if err != nil {
return "", explodedArgs, err
}
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
transactionalClause := ""
if transactionalTable {
transactionalClause = "lock in share mode"
}
result = fmt.Sprintf(`
insert /* gh-osc %s.%s */ ignore into %s.%s (%s)
(select %s from %s.%s force index (%s)
where (%s and %s)
where (%s and %s) %s
)
`, databaseName, originalTableName, databaseName, ghostTableName, sharedColumnsListing,
sharedColumnsListing, databaseName, originalTableName, uniqueKey,
rangeStartComparison, rangeEndComparison)
return query, nil
rangeStartComparison, rangeEndComparison, transactionalClause)
return result, explodedArgs, nil
}
func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string) (string, error) {
func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) {
rangeStartValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
rangeEndValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
for i := range uniqueKeyColumns {
rangeStartValues[i] = "?"
rangeEndValues[i] = "?"
}
return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues)
return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable)
}
func BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName string, uniqueKeyColumns []string, chunkSize int) (string, error) {
func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, hint string) (result string, explodedArgs []interface{}, err error) {
if len(uniqueKeyColumns) == 0 {
return "", fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery")
return "", explodedArgs, fmt.Errorf("Got 0 columns in BuildUniqueKeyRangeEndPreparedQuery")
}
databaseName = EscapeName(databaseName)
originalTableName = EscapeName(originalTableName)
tableName = EscapeName(tableName)
rangeStartComparison, err := BuildRangePreparedComparison(uniqueKeyColumns, GreaterThanComparisonSign)
rangeStartComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeStartArgs, GreaterThanComparisonSign)
if err != nil {
return "", err
return "", explodedArgs, err
}
rangeEndComparison, err := BuildRangePreparedComparison(uniqueKeyColumns, LessThanOrEqualsComparisonSign)
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
rangeEndComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeEndArgs, LessThanOrEqualsComparisonSign)
if err != nil {
return "", err
return "", explodedArgs, err
}
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
for i := range uniqueKeyColumns {
@ -185,8 +205,8 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName string,
uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumns[i])
uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumns[i])
}
query := fmt.Sprintf(`
select /* gh-osc %s.%s */ %s
result = fmt.Sprintf(`
select /* gh-osc %s.%s %s */ %s
from (
select
%s
@ -200,11 +220,45 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName string,
order by
%s
limit 1
`, databaseName, originalTableName, strings.Join(uniqueKeyColumns, ", "),
strings.Join(uniqueKeyColumns, ", "), databaseName, originalTableName,
`, databaseName, tableName, hint, strings.Join(uniqueKeyColumns, ", "),
strings.Join(uniqueKeyColumns, ", "), databaseName, tableName,
rangeStartComparison, rangeEndComparison,
strings.Join(uniqueKeyColumnAscending, ", "), chunkSize,
strings.Join(uniqueKeyColumnDescending, ", "),
)
return result, explodedArgs, nil
}
func BuildUniqueKeyMinValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) {
return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "asc")
}
func BuildUniqueKeyMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) {
return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "desc")
}
func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, order string) (string, error) {
if len(uniqueKeyColumns) == 0 {
return "", fmt.Errorf("Got 0 columns in BuildUniqueKeyMinMaxValuesPreparedQuery")
}
databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName)
uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
for i := range uniqueKeyColumns {
uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i])
uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumns[i], order)
}
query := fmt.Sprintf(`
select /* gh-osc %s.%s */ %s
from
%s.%s
order by
%s
limit 1
`, databaseName, tableName, strings.Join(uniqueKeyColumns, ", "),
databaseName, tableName,
strings.Join(uniqueKeyColumnOrder, ", "),
)
return query, nil
}

View File

@ -8,6 +8,7 @@ package sql
import (
"testing"
"reflect"
"regexp"
"strings"
@ -71,48 +72,60 @@ func TestBuildRangeComparison(t *testing.T) {
{
columns := []string{"c1"}
values := []string{"@v1"}
comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign)
args := []interface{}{3}
comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanComparisonSign)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(comparison, "((`c1` < @v1))")
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3}))
}
{
columns := []string{"c1"}
values := []string{"@v1"}
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
args := []interface{}{3}
comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or ((`c1` = @v1)))")
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3}))
}
{
columns := []string{"c1", "c2"}
values := []string{"@v1", "@v2"}
comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign)
args := []interface{}{3, 17}
comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanComparisonSign)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)))")
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17}))
}
{
columns := []string{"c1", "c2"}
values := []string{"@v1", "@v2"}
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
args := []interface{}{3, 17}
comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or ((`c1` = @v1) and (`c2` = @v2)))")
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17}))
}
{
columns := []string{"c1", "c2", "c3"}
values := []string{"@v1", "@v2", "@v3"}
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
args := []interface{}{3, 17, 22}
comparison, explodedArgs, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or (((`c1` = @v1) and (`c2` = @v2)) AND (`c3` < @v3)) or ((`c1` = @v1) and (`c2` = @v2) and (`c3` = @v3)))")
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 22, 3, 17, 22}))
}
{
columns := []string{"c1"}
values := []string{"@v1", "@v2"}
_, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
args := []interface{}{3, 17}
_, _, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign)
test.S(t).ExpectNotNil(err)
}
{
columns := []string{}
values := []string{}
_, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
args := []interface{}{}
_, _, err := BuildRangeComparison(columns, values, args, LessThanOrEqualsComparisonSign)
test.S(t).ExpectNotNil(err)
}
}
@ -127,8 +140,10 @@ func TestBuildRangeInsertQuery(t *testing.T) {
uniqueKeyColumns := []string{"id"}
rangeStartValues := []string{"@v1s"}
rangeEndValues := []string{"@v1e"}
rangeStartArgs := []interface{}{3}
rangeEndArgs := []interface{}{103}
query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues)
query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true)
test.S(t).ExpectNil(err)
expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
@ -137,14 +152,17 @@ func TestBuildRangeInsertQuery(t *testing.T) {
)
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 103, 103}))
}
{
uniqueKey := "name_position_uidx"
uniqueKeyColumns := []string{"name", "position"}
rangeStartValues := []string{"@v1s", "@v2s"}
rangeEndValues := []string{"@v1e", "@v2e"}
rangeStartArgs := []interface{}{3, 17}
rangeEndArgs := []interface{}{103, 117}
query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues)
query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true)
test.S(t).ExpectNil(err)
expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
@ -153,6 +171,7 @@ func TestBuildRangeInsertQuery(t *testing.T) {
)
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117}))
}
}
@ -164,8 +183,10 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) {
{
uniqueKey := "name_position_uidx"
uniqueKeyColumns := []string{"name", "position"}
rangeStartArgs := []interface{}{3, 17}
rangeEndArgs := []interface{}{103, 117}
query, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns)
query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true)
test.S(t).ExpectNil(err)
expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
@ -174,6 +195,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) {
)
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117}))
}
}
@ -183,11 +205,13 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) {
chunkSize := 500
{
uniqueKeyColumns := []string{"name", "position"}
rangeStartArgs := []interface{}{3, 17}
rangeEndArgs := []interface{}{103, 117}
query, err := BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName, uniqueKeyColumns, chunkSize)
query, explodedArgs, err := BuildUniqueKeyRangeEndPreparedQuery(databaseName, originalTableName, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, chunkSize, "test")
test.S(t).ExpectNil(err)
expected := `
select /* gh-osc mydb.tbl */ name, position
select /* gh-osc mydb.tbl test */ name, position
from (
select
name, position
@ -203,5 +227,38 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) {
limit 1
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 103, 103, 117, 103, 117}))
}
}
func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) {
databaseName := "mydb"
originalTableName := "tbl"
uniqueKeyColumns := []string{"name", "position"}
{
query, err := BuildUniqueKeyMinValuesPreparedQuery(databaseName, originalTableName, uniqueKeyColumns)
test.S(t).ExpectNil(err)
expected := `
select /* gh-osc mydb.tbl */ name, position
from
mydb.tbl
order by
name asc, position asc
limit 1
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
}
{
query, err := BuildUniqueKeyMaxValuesPreparedQuery(databaseName, originalTableName, uniqueKeyColumns)
test.S(t).ExpectNil(err)
expected := `
select /* gh-osc mydb.tbl */ name, position
from
mydb.tbl
order by
name desc, position desc
limit 1
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
}
}

View File

@ -35,7 +35,7 @@ type UniqueKey struct {
HasNullable bool
}
// IsPrimary cehcks if this unique key is primary
// IsPrimary checks if this unique key is primary
func (this *UniqueKey) IsPrimary() bool {
return this.Name == "PRIMARY"
}
@ -43,3 +43,52 @@ func (this *UniqueKey) IsPrimary() bool {
func (this *UniqueKey) String() string {
return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable)
}
type ColumnValues struct {
abstractValues []interface{}
ValuesPointers []interface{}
}
func NewColumnValues(length int) *ColumnValues {
result := &ColumnValues{
abstractValues: make([]interface{}, length),
ValuesPointers: make([]interface{}, length),
}
for i := 0; i < length; i++ {
result.ValuesPointers[i] = &result.abstractValues[i]
}
return result
}
func ToColumnValues(abstractValues []interface{}) *ColumnValues {
result := &ColumnValues{
abstractValues: abstractValues,
ValuesPointers: make([]interface{}, len(abstractValues)),
}
for i := 0; i < len(abstractValues); i++ {
result.ValuesPointers[i] = &result.abstractValues[i]
}
return result
}
func (this *ColumnValues) AbstractValues() []interface{} {
return this.abstractValues
}
func (this *ColumnValues) StringColumn(index int) string {
val := this.AbstractValues()[index]
if ints, ok := val.([]uint8); ok {
return string(ints)
}
return fmt.Sprintf("%+v", val)
}
func (this *ColumnValues) String() string {
stringValues := []string{}
for i := range this.AbstractValues() {
stringValues = append(stringValues, this.StringColumn(i))
}
return strings.Join(stringValues, ",")
}