Merge pull request #16 from github/ongoing-initial-work-2

Ongoing development.
This commit is contained in:
Shlomi Noach 2016-04-14 13:39:48 +02:00
commit 75c3fe0bee
9 changed files with 1048 additions and 251 deletions

View File

@@ -7,7 +7,9 @@ package base
import (
"fmt"
+"strconv"
"strings"
+"sync"
"sync/atomic"
"time"
@ -34,6 +36,10 @@ type MigrationContext struct {
DatabaseName string DatabaseName string
OriginalTableName string OriginalTableName string
AlterStatement string AlterStatement string
Noop bool
TestOnReplica bool
TableEngine string TableEngine string
CountTableRows bool CountTableRows bool
RowsEstimate int64 RowsEstimate int64
@@ -43,21 +49,31 @@ type MigrationContext struct {
OriginalBinlogRowImage string
AllowedRunningOnMaster bool
InspectorConnectionConfig *mysql.ConnectionConfig
-MasterConnectionConfig *mysql.ConnectionConfig
-MigrationRangeMinValues *sql.ColumnValues
-MigrationRangeMaxValues *sql.ColumnValues
-Iteration int64
-MigrationIterationRangeMinValues *sql.ColumnValues
-MigrationIterationRangeMaxValues *sql.ColumnValues
-UniqueKey *sql.UniqueKey
+ApplierConnectionConfig *mysql.ConnectionConfig
StartTime time.Time
RowCopyStartTime time.Time
CurrentLag int64
MaxLagMillisecondsThrottleThreshold int64
ThrottleFlagFile string
+ThrottleAdditionalFlagFile string
TotalRowsCopied int64
+isThrottled bool
+throttleReason string
+throttleMutex *sync.Mutex
+MaxLoad map[string]int64
+OriginalTableColumns *sql.ColumnList
+OriginalTableUniqueKeys [](*sql.UniqueKey)
+GhostTableColumns *sql.ColumnList
+GhostTableUniqueKeys [](*sql.UniqueKey)
+UniqueKey *sql.UniqueKey
+SharedColumns *sql.ColumnList
+MigrationRangeMinValues *sql.ColumnValues
+MigrationRangeMaxValues *sql.ColumnValues
+Iteration int64
+MigrationIterationRangeMinValues *sql.ColumnValues
+MigrationIterationRangeMaxValues *sql.ColumnValues
-IsThrottled func() bool
CanStopStreaming func() bool
}
@ -71,8 +87,10 @@ func newMigrationContext() *MigrationContext {
return &MigrationContext{ return &MigrationContext{
ChunkSize: 1000, ChunkSize: 1000,
InspectorConnectionConfig: mysql.NewConnectionConfig(), InspectorConnectionConfig: mysql.NewConnectionConfig(),
-MasterConnectionConfig: mysql.NewConnectionConfig(),
+ApplierConnectionConfig: mysql.NewConnectionConfig(),
MaxLagMillisecondsThrottleThreshold: 1000, MaxLagMillisecondsThrottleThreshold: 1000,
MaxLoad: make(map[string]int64),
throttleMutex: &sync.Mutex{},
} }
} }
@ -86,6 +104,11 @@ func (this *MigrationContext) GetGhostTableName() string {
return fmt.Sprintf("_%s_New", this.OriginalTableName) return fmt.Sprintf("_%s_New", this.OriginalTableName)
} }
// GetOldTableName generates the name of the "old" table, into which the original table is renamed.
func (this *MigrationContext) GetOldTableName() string {
return fmt.Sprintf("_%s_Old", this.OriginalTableName)
}
// GetChangelogTableName generates the name of changelog table, based on original table name // GetChangelogTableName generates the name of changelog table, based on original table name
func (this *MigrationContext) GetChangelogTableName() string { func (this *MigrationContext) GetChangelogTableName() string {
return fmt.Sprintf("_%s_OSC", this.OriginalTableName) return fmt.Sprintf("_%s_OSC", this.OriginalTableName)
@@ -96,10 +119,11 @@ func (this *MigrationContext) RequiresBinlogFormatChange() bool {
return this.OriginalBinlogFormat != "ROW"
}
-// IsRunningOnMaster is `true` when the app connects directly to the master (typically
-// it should be executed on replica and infer the master)
-func (this *MigrationContext) IsRunningOnMaster() bool {
-return this.InspectorConnectionConfig.Equals(this.MasterConnectionConfig)
+// InspectorIsAlsoApplier is `true` when the both inspector and applier are the
+// same database instance. This would be true when running directly on master or when
+// testing on replica.
+func (this *MigrationContext) InspectorIsAlsoApplier() bool {
+return this.InspectorConnectionConfig.Equals(this.ApplierConnectionConfig)
}
// HasMigrationRange tells us whether there's a range to iterate for copying rows. // HasMigrationRange tells us whether there's a range to iterate for copying rows.
@ -141,3 +165,42 @@ func (this *MigrationContext) ElapsedRowCopyTime() time.Duration {
func (this *MigrationContext) GetTotalRowsCopied() int64 { func (this *MigrationContext) GetTotalRowsCopied() int64 {
return atomic.LoadInt64(&this.TotalRowsCopied) return atomic.LoadInt64(&this.TotalRowsCopied)
} }
func (this *MigrationContext) GetIteration() int64 {
return atomic.LoadInt64(&this.Iteration)
}
func (this *MigrationContext) SetThrottled(throttle bool, reason string) {
this.throttleMutex.Lock()
defer func() { this.throttleMutex.Unlock() }()
this.isThrottled = throttle
this.throttleReason = reason
}
func (this *MigrationContext) IsThrottled() (bool, string) {
this.throttleMutex.Lock()
defer func() { this.throttleMutex.Unlock() }()
return this.isThrottled, this.throttleReason
}
func (this *MigrationContext) ReadMaxLoad(maxLoadList string) error {
if maxLoadList == "" {
return nil
}
maxLoadConditions := strings.Split(maxLoadList, ",")
for _, maxLoadCondition := range maxLoadConditions {
maxLoadTokens := strings.Split(maxLoadCondition, "=")
if len(maxLoadTokens) != 2 {
return fmt.Errorf("Error parsing max-load condition: %s", maxLoadCondition)
}
if maxLoadTokens[0] == "" {
return fmt.Errorf("Error parsing status variable in max-load condition: %s", maxLoadCondition)
}
if n, err := strconv.ParseInt(maxLoadTokens[1], 10, 0); err != nil {
return fmt.Errorf("Error parsing numeric value in max-load condition: %s", maxLoadCondition)
} else {
this.MaxLoad[maxLoadTokens[0]] = n
}
}
return nil
}
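A quick usage sketch of the new ReadMaxLoad parser; the input string mirrors the example later given for the --max-load flag, and the surrounding wiring is hypothetical:

```go
context := GetMigrationContext()
// "status-variable=threshold" pairs, comma separated:
if err := context.ReadMaxLoad("Threads_running=100,Threads_connected=500"); err != nil {
	// rejected when an "=" is missing, the variable name is empty, or the threshold is non-numeric
}
// context.MaxLoad now holds {"Threads_running": 100, "Threads_connected": 500};
// the throttler compares these against SHOW GLOBAL STATUS values (see the applier's
// ShowStatusVariable and the migrator's shouldThrottle further down in this PR).
```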

View File

@ -19,11 +19,6 @@ import (
func main() { func main() {
migrationContext := base.GetMigrationContext() migrationContext := base.GetMigrationContext()
// mysqlBasedir := flag.String("mysql-basedir", "", "the --basedir config for MySQL (auto-detected if not given)")
// mysqlDatadir := flag.String("mysql-datadir", "", "the --datadir config for MySQL (auto-detected if not given)")
internalExperiment := flag.Bool("internal-experiment", false, "issue an internal experiment")
binlogFile := flag.String("binlog-file", "", "Name of binary log file")
flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)") flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)")
flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)") flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
flag.StringVar(&migrationContext.InspectorConnectionConfig.User, "user", "root", "MySQL user") flag.StringVar(&migrationContext.InspectorConnectionConfig.User, "user", "root", "MySQL user")
@ -35,9 +30,20 @@ func main() {
flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)") flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)")
flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica") flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica")
-flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration")
-flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists")
+executeFlag := flag.Bool("execute", false, "actually execute the alter & migrate the table. Default is noop: do some tests and exit")
+flag.BoolVar(&migrationContext.TestOnReplica, "test-on-replica", false, "Have the migration run on a replica, not on the master. At the end of migration tables are not swapped; gh-osc issues `STOP SLAVE` and you can compare the two tables for building trust")
flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 100-100,000)")
if migrationContext.ChunkSize < 100 {
migrationContext.ChunkSize = 100
}
if migrationContext.ChunkSize > 100000 {
migrationContext.ChunkSize = 100000
}
flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1500, "replication lag at which to throttle operation")
flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered")
flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations")
maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'")
quiet := flag.Bool("quiet", false, "quiet") quiet := flag.Bool("quiet", false, "quiet")
verbose := flag.Bool("verbose", false, "verbose") verbose := flag.Bool("verbose", false, "verbose")
debug := flag.Bool("debug", false, "debug mode (very verbose)") debug := flag.Bool("debug", false, "debug mode (very verbose)")
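One observation on the chunk-size bounds above: in this hunk the two clamping `if` blocks sit among the flag definitions, i.e. they run before the command line is parsed (flag.Parse() is outside this hunk), so at that point ChunkSize still holds its default of 1000 and an out-of-range user value would not actually be clamped. A hedged sketch of applying the same bounds after parsing, assuming flag.Parse() is called later in main() as usual:

```go
flag.Parse()
if migrationContext.ChunkSize < 100 {
	migrationContext.ChunkSize = 100
}
if migrationContext.ChunkSize > 100000 {
	migrationContext.ChunkSize = 100000
}
```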
@ -75,23 +81,16 @@ func main() {
if migrationContext.AlterStatement == "" { if migrationContext.AlterStatement == "" {
log.Fatalf("--alter must be provided and statement must not be empty") log.Fatalf("--alter must be provided and statement must not be empty")
} }
migrationContext.Noop = !(*executeFlag)
if migrationContext.AllowedRunningOnMaster && migrationContext.TestOnReplica {
log.Fatalf("--allow-on-master and --test-on-replica are mutually exclusive")
}
if err := migrationContext.ReadMaxLoad(*maxLoad); err != nil {
log.Fatale(err)
}
log.Info("starting gh-osc") log.Info("starting gh-osc")
if *internalExperiment {
log.Debug("starting experiment with %+v", *binlogFile)
//binlogReader = binlog.NewMySQLBinlogReader(*mysqlBasedir, *mysqlDatadir)
// binlogReader, err := binlog.NewGoMySQLReader(migrationContext.InspectorConnectionConfig)
// if err != nil {
// log.Fatale(err)
// }
// if err := binlogReader.ConnectBinlogStreamer(mysql.BinlogCoordinates{LogFile: *binlogFile, LogPos: 0}); err != nil {
// log.Fatale(err)
// }
// binlogReader.StreamEvents(func() bool { return false })
// return
}
migrator := logic.NewMigrator() migrator := logic.NewMigrator()
err := migrator.Migrate() err := migrator.Migrate()
if err != nil { if err != nil {

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/github/gh-osc/go/base" "github.com/github/gh-osc/go/base"
"github.com/github/gh-osc/go/binlog"
"github.com/github/gh-osc/go/mysql" "github.com/github/gh-osc/go/mysql"
"github.com/github/gh-osc/go/sql" "github.com/github/gh-osc/go/sql"
@ -19,10 +20,6 @@ import (
"github.com/outbrain/golib/sqlutils" "github.com/outbrain/golib/sqlutils"
) )
const (
heartbeatIntervalSeconds = 1
)
// Applier reads data from the read-MySQL-server (typically a replica, but can be the master) // Applier reads data from the read-MySQL-server (typically a replica, but can be the master)
// It is used for gaining initial status and structure, and later also follow up on progress and changelog // It is used for gaining initial status and structure, and later also follow up on progress and changelog
type Applier struct { type Applier struct {
@ -33,7 +30,7 @@ type Applier struct {
func NewApplier() *Applier { func NewApplier() *Applier {
return &Applier{ return &Applier{
-connectionConfig: base.GetMigrationContext().MasterConnectionConfig,
+connectionConfig: base.GetMigrationContext().ApplierConnectionConfig,
migrationContext: base.GetMigrationContext(), migrationContext: base.GetMigrationContext(),
} }
} }
@ -63,7 +60,7 @@ func (this *Applier) validateConnection() error {
return nil return nil
} }
-// CreateGhostTable creates the ghost table on the master
+// CreateGhostTable creates the ghost table on the applier host
func (this *Applier) CreateGhostTable() error { func (this *Applier) CreateGhostTable() error {
query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`, query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
@ -82,7 +79,7 @@ func (this *Applier) CreateGhostTable() error {
return nil return nil
} }
-// CreateGhostTable creates the ghost table on the master
+// AlterGhost applies `alter` statement on ghost table
func (this *Applier) AlterGhost() error { func (this *Applier) AlterGhost() error {
query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`, query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
@ -101,16 +98,16 @@ func (this *Applier) AlterGhost() error {
return nil return nil
} }
-// CreateChangelogTable creates the changelog table on the master
+// CreateChangelogTable creates the changelog table on the applier host
func (this *Applier) CreateChangelogTable() error {
query := fmt.Sprintf(`create /* gh-osc */ table %s.%s (
-id int auto_increment,
+id bigint auto_increment,
last_update timestamp not null DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
hint varchar(64) charset ascii not null,
-value varchar(64) charset ascii not null,
+value varchar(255) charset ascii not null,
primary key(id),
unique key hint_uidx(hint)
-) auto_increment=2
+) auto_increment=256
`, `,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()), sql.EscapeName(this.migrationContext.GetChangelogTableName()),
@ -126,31 +123,50 @@ func (this *Applier) CreateChangelogTable() error {
return nil return nil
} }
-// DropChangelogTable drops the changelog table on the master
-func (this *Applier) DropChangelogTable() error {
+// dropTable drops a given table on the applied host
+func (this *Applier) dropTable(tableName string) error {
query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`,
sql.EscapeName(this.migrationContext.DatabaseName),
-sql.EscapeName(this.migrationContext.GetChangelogTableName()),
+sql.EscapeName(tableName),
)
-log.Infof("Droppping changelog table %s.%s",
+log.Infof("Droppping table %s.%s",
sql.EscapeName(this.migrationContext.DatabaseName),
-sql.EscapeName(this.migrationContext.GetChangelogTableName()),
+sql.EscapeName(tableName),
)
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err
}
-log.Infof("Changelog table dropped")
+log.Infof("Table dropped")
return nil
}
// DropChangelogTable drops the changelog table on the applier host
func (this *Applier) DropChangelogTable() error {
return this.dropTable(this.migrationContext.GetChangelogTableName())
}
// DropGhostTable drops the ghost table on the applier host
func (this *Applier) DropGhostTable() error {
return this.dropTable(this.migrationContext.GetGhostTableName())
}
// WriteChangelog writes a value to the changelog table. // WriteChangelog writes a value to the changelog table.
// It returns the hint as given, for convenience // It returns the hint as given, for convenience
func (this *Applier) WriteChangelog(hint, value string) (string, error) { func (this *Applier) WriteChangelog(hint, value string) (string, error) {
explicitId := 0
switch hint {
case "heartbeat":
explicitId = 1
case "state":
explicitId = 2
case "throttle":
explicitId = 3
}
query := fmt.Sprintf(` query := fmt.Sprintf(`
insert /* gh-osc */ into %s.%s insert /* gh-osc */ into %s.%s
(id, hint, value) (id, hint, value)
values values
-(NULL, ?, ?)
+(NULLIF(?, 0), ?, ?)
on duplicate key update on duplicate key update
last_update=NOW(), last_update=NOW(),
value=VALUES(value) value=VALUES(value)
@ -158,29 +174,25 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) {
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()), sql.EscapeName(this.migrationContext.GetChangelogTableName()),
) )
-_, err := sqlutils.Exec(this.db, query, hint, value)
+_, err := sqlutils.Exec(this.db, query, explicitId, hint, value)
return hint, err return hint, err
} }
func (this *Applier) WriteAndLogChangelog(hint, value string) (string, error) {
this.WriteChangelog(hint, value)
return this.WriteChangelog(fmt.Sprintf("%s at %d", hint, time.Now().UnixNano()), value)
}
func (this *Applier) WriteChangelogState(value string) (string, error) {
return this.WriteAndLogChangelog("state", value)
}
// InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table.
// This is done asynchronously // This is done asynchronously
-func (this *Applier) InitiateHeartbeat() {
+func (this *Applier) InitiateHeartbeat(heartbeatIntervalMilliseconds int64) {
go func() {
numSuccessiveFailures := 0 numSuccessiveFailures := 0
query := fmt.Sprintf(`
insert /* gh-osc */ into %s.%s
(id, hint, value)
values
(1, 'heartbeat', ?)
on duplicate key update
last_update=NOW(),
value=VALUES(value)
`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
injectHeartbeat := func() error { injectHeartbeat := func() error {
-if _, err := sqlutils.ExecNoPrepare(this.db, query, time.Now().Format(time.RFC3339)); err != nil {
+if _, err := this.WriteChangelog("heartbeat", time.Now().Format(time.RFC3339)); err != nil {
numSuccessiveFailures++ numSuccessiveFailures++
if numSuccessiveFailures > this.migrationContext.MaxRetries() { if numSuccessiveFailures > this.migrationContext.MaxRetries() {
return log.Errore(err) return log.Errore(err)
@ -192,7 +204,7 @@ func (this *Applier) InitiateHeartbeat() {
} }
injectHeartbeat() injectHeartbeat()
-heartbeatTick := time.Tick(time.Duration(heartbeatIntervalSeconds) * time.Second)
+heartbeatTick := time.Tick(time.Duration(heartbeatIntervalMilliseconds) * time.Millisecond)
for range heartbeatTick { for range heartbeatTick {
// Generally speaking, we would issue a goroutine, but I'd actually rather // Generally speaking, we would issue a goroutine, but I'd actually rather
// have this blocked rather than spam the master in the event something // have this blocked rather than spam the master in the event something
@ -201,13 +213,12 @@ func (this *Applier) InitiateHeartbeat() {
return return
} }
} }
}()
} }
// ReadMigrationMinValues // ReadMigrationMinValues
func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error {
log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) log.Debugf("Reading migration range according to key: %s", uniqueKey.Name)
-query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns)
+query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names)
if err != nil { if err != nil {
return err return err
} }
@ -216,7 +227,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error {
return err return err
} }
for rows.Next() { for rows.Next() {
-this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(len(uniqueKey.Columns))
+this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(uniqueKey.Len())
if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil { if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil {
return err return err
} }
@ -228,7 +239,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error {
// ReadMigrationMinValues // ReadMigrationMinValues
func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error {
log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) log.Debugf("Reading migration range according to key: %s", uniqueKey.Name)
-query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns)
+query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names)
if err != nil { if err != nil {
return err return err
} }
@ -237,7 +248,7 @@ func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error {
return err return err
} }
for rows.Next() { for rows.Next() {
-this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(len(uniqueKey.Columns))
+this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(uniqueKey.Len())
if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil { if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil {
return err return err
} }
@ -266,12 +277,12 @@ func (this *Applier) __unused_IterationIsComplete() (bool, error) {
return false, nil return false, nil
} }
args := sqlutils.Args() args := sqlutils.Args()
-compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign)
+compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign)
if err != nil { if err != nil {
return false, err return false, err
} }
args = append(args, explodedArgs...) args = append(args, explodedArgs...)
-compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign)
+compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -311,11 +322,11 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo
query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery(
this.migrationContext.DatabaseName, this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableName,
-this.migrationContext.UniqueKey.Columns,
+this.migrationContext.UniqueKey.Columns.Names,
this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(),
this.migrationContext.MigrationRangeMaxValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(),
this.migrationContext.ChunkSize, this.migrationContext.ChunkSize,
fmt.Sprintf("iteration:%d", this.migrationContext.Iteration), fmt.Sprintf("iteration:%d", this.migrationContext.GetIteration()),
) )
if err != nil { if err != nil {
return hasFurtherRange, err return hasFurtherRange, err
@ -324,7 +335,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo
if err != nil { if err != nil {
return hasFurtherRange, err return hasFurtherRange, err
} }
-iterationRangeMaxValues := sql.NewColumnValues(len(this.migrationContext.UniqueKey.Columns))
+iterationRangeMaxValues := sql.NewColumnValues(this.migrationContext.UniqueKey.Len())
for rows.Next() { for rows.Next() {
if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil { if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil {
return hasFurtherRange, err return hasFurtherRange, err
@ -332,17 +343,10 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo
hasFurtherRange = true hasFurtherRange = true
} }
if !hasFurtherRange { if !hasFurtherRange {
log.Debugf("Iteration complete: cannot find iteration end") log.Debugf("Iteration complete: no further range to iterate")
return hasFurtherRange, nil return hasFurtherRange, nil
} }
this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues
log.Debugf(
"column values: [%s]..[%s]; iteration: %d; chunk-size: %d",
this.migrationContext.MigrationIterationRangeMinValues,
this.migrationContext.MigrationIterationRangeMaxValues,
this.migrationContext.Iteration,
this.migrationContext.ChunkSize,
)
return hasFurtherRange, nil return hasFurtherRange, nil
} }
@ -354,12 +358,12 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
this.migrationContext.DatabaseName, this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableName,
this.migrationContext.GetGhostTableName(), this.migrationContext.GetGhostTableName(),
-this.migrationContext.UniqueKey.Columns,
+this.migrationContext.SharedColumns.Names,
this.migrationContext.UniqueKey.Name, this.migrationContext.UniqueKey.Name,
-this.migrationContext.UniqueKey.Columns,
+this.migrationContext.UniqueKey.Columns.Names,
this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(),
this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(),
-this.migrationContext.Iteration == 0,
+this.migrationContext.GetIteration() == 0,
this.migrationContext.IsTransactionalTable(), this.migrationContext.IsTransactionalTable(),
) )
if err != nil { if err != nil {
@ -371,21 +375,22 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
} }
rowsAffected, _ = sqlResult.RowsAffected() rowsAffected, _ = sqlResult.RowsAffected()
duration = time.Now().Sub(startTime) duration = time.Now().Sub(startTime)
this.WriteChangelog(
fmt.Sprintf("copy iteration %d", this.migrationContext.Iteration),
fmt.Sprintf("chunk: %d; affected: %d; duration: %d", chunkSize, rowsAffected, duration),
)
log.Debugf( log.Debugf(
"Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d", "Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d",
this.migrationContext.MigrationIterationRangeMinValues, this.migrationContext.MigrationIterationRangeMinValues,
this.migrationContext.MigrationIterationRangeMaxValues, this.migrationContext.MigrationIterationRangeMaxValues,
-this.migrationContext.Iteration,
+this.migrationContext.GetIteration(),
chunkSize) chunkSize)
return chunkSize, rowsAffected, duration, nil return chunkSize, rowsAffected, duration, nil
} }
// LockTables // LockTables
func (this *Applier) LockTables() error { func (this *Applier) LockTables() error {
if this.migrationContext.Noop {
log.Debugf("Noop operation; not really locking tables")
return nil
}
query := fmt.Sprintf(`lock /* gh-osc */ tables %s.%s write, %s.%s write, %s.%s write`, query := fmt.Sprintf(`lock /* gh-osc */ tables %s.%s write, %s.%s write, %s.%s write`,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.OriginalTableName), sql.EscapeName(this.migrationContext.OriginalTableName),
@ -412,3 +417,44 @@ func (this *Applier) UnlockTables() error {
log.Infof("Tables unlocked") log.Infof("Tables unlocked")
return nil return nil
} }
func (this *Applier) ShowStatusVariable(variableName string) (result int64, err error) {
query := fmt.Sprintf(`show global status like '%s'`, variableName)
if err := this.db.QueryRow(query).Scan(&variableName, &result); err != nil {
return 0, err
}
return result, nil
}
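ShowStatusVariable is what the new --max-load support is built on; a minimal usage sketch (the threshold value here is hypothetical):

```go
// Throttle when a status counter exceeds a configured threshold:
value, err := applier.ShowStatusVariable("Threads_running")
if err != nil {
	// treat an unreadable status variable as a reason to throttle
}
if value > 100 {
	// back off before issuing the next copy chunk or DML event
}
```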
func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (query string, args []interface{}, err error) {
switch dmlEvent.DML {
case binlog.DeleteDML:
{
query, uniqueKeyArgs, err := sql.BuildDMLDeleteQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.WhereColumnValues.AbstractValues())
return query, uniqueKeyArgs, err
}
case binlog.InsertDML:
{
query, sharedArgs, err := sql.BuildDMLInsertQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, dmlEvent.NewColumnValues.AbstractValues())
return query, sharedArgs, err
}
case binlog.UpdateDML:
{
query, sharedArgs, uniqueKeyArgs, err := sql.BuildDMLUpdateQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues())
args = append(args, sharedArgs...)
args = append(args, uniqueKeyArgs...)
return query, args, err
}
}
return "", args, fmt.Errorf("Unknown dml event type: %+v", dmlEvent.DML)
}
func (this *Applier) ApplyDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) error {
query, args, err := this.buildDMLEventQuery(dmlEvent)
if err != nil {
return err
}
log.Errorf(query)
_, err = sqlutils.Exec(this.db, query, args...)
return err
}
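For orientation, buildDMLEventQuery translates each row event captured from the original table into an equivalent statement against the ghost table; the concrete SQL text comes from the sql.BuildDML* helpers, which are outside this diff, so the following is only the shape of the mapping:

```go
// DELETE on the original table -> delete from the ghost table, WHERE bound to the unique key values
// INSERT on the original table -> insert into the ghost table, copying only the shared columns
// UPDATE on the original table -> update the ghost table's shared columns, WHERE bound to the unique key values
```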

View File

@ -44,6 +44,9 @@ func (this *Inspector) InitDBConnections() (err error) {
if err := this.validateGrants(); err != nil { if err := this.validateGrants(); err != nil {
return err return err
} }
if err := this.restartReplication(); err != nil {
return err
}
if err := this.validateBinlogs(); err != nil { if err := this.validateBinlogs(); err != nil {
return err return err
} }
@ -69,15 +72,54 @@ func (this *Inspector) ValidateOriginalTable() (err error) {
return nil return nil
} }
-func (this *Inspector) InspectOriginalTable() (uniqueKeys [](*sql.UniqueKey), err error) {
-uniqueKeys, err = this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName)
+func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) {
+uniqueKeys, err = this.getCandidateUniqueKeys(tableName)
if err != nil {
-return uniqueKeys, err
+return columns, uniqueKeys, err
}
if len(uniqueKeys) == 0 {
-return uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out")
+return columns, uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out")
}
-return uniqueKeys, err
+columns, err = this.getTableColumns(this.migrationContext.DatabaseName, tableName)
if err != nil {
return columns, uniqueKeys, err
}
return columns, uniqueKeys, nil
}
func (this *Inspector) InspectOriginalTable() (err error) {
this.migrationContext.OriginalTableColumns, this.migrationContext.OriginalTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.OriginalTableName)
if err == nil {
return err
}
return nil
}
func (this *Inspector) InspectOriginalAndGhostTables() (err error) {
this.migrationContext.GhostTableColumns, this.migrationContext.GhostTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.GetGhostTableName())
if err != nil {
return err
}
sharedUniqueKeys, err := this.getSharedUniqueKeys(this.migrationContext.OriginalTableUniqueKeys, this.migrationContext.GhostTableUniqueKeys)
if err != nil {
return err
}
if len(sharedUniqueKeys) == 0 {
return fmt.Errorf("No shared unique key can be found after ALTER! Bailing out")
}
this.migrationContext.UniqueKey = sharedUniqueKeys[0]
log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name)
if !this.migrationContext.UniqueKey.IsPrimary() {
if this.migrationContext.OriginalBinlogRowImage != "full" {
return fmt.Errorf("binlog_row_image is '%s' and chosen key is %s, which is not the primary key. This operation cannot proceed. You may `set global binlog_row_image='full'` and try again")
}
}
this.migrationContext.SharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns)
log.Infof("Shared columns are %s", this.migrationContext.SharedColumns)
// By fact that a non-empty unique key exists we also know the shared columns are non-empty
return nil
} }
// validateConnection issues a simple can-connect to MySQL // validateConnection issues a simple can-connect to MySQL
@ -136,6 +178,32 @@ func (this *Inspector) validateGrants() error {
return log.Errorf("User has insufficient privileges for migration.") return log.Errorf("User has insufficient privileges for migration.")
} }
// restartReplication is required so that we are _certain_ the binlog format and
// row image settings have actually been applied to the replication thread.
// It is entriely possible, for example, that the replication is using 'STATEMENT'
// binlog format even as the variable says 'ROW'
func (this *Inspector) restartReplication() error {
log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port)
masterKey, _ := getMasterKeyFromSlaveStatus(this.connectionConfig)
if masterKey == nil {
// This is not a replica
return nil
}
var stopError, startError error
_, stopError = sqlutils.ExecNoPrepare(this.db, `stop slave`)
_, startError = sqlutils.ExecNoPrepare(this.db, `start slave`)
if stopError != nil {
return stopError
}
if startError != nil {
return startError
}
log.Debugf("Replication restarted")
return nil
}
// validateBinlogs checks that binary log configuration is good to go // validateBinlogs checks that binary log configuration is good to go
func (this *Inspector) validateBinlogs() error { func (this *Inspector) validateBinlogs() error {
query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format` query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format`
@ -264,27 +332,28 @@ func (this *Inspector) countTableRows() error {
return nil return nil
} }
-func (this *Inspector) getTableColumns(databaseName, tableName string) (columns sql.ColumnList, err error) {
+func (this *Inspector) getTableColumns(databaseName, tableName string) (*sql.ColumnList, error) {
query := fmt.Sprintf(`
show columns from %s.%s
`,
sql.EscapeName(databaseName),
sql.EscapeName(tableName),
)
-err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error {
-columns = append(columns, rowMap.GetString("Field"))
+columnNames := []string{}
+err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error {
+columnNames = append(columnNames, rowMap.GetString("Field"))
return nil
})
if err != nil {
-return columns, err
+return nil, err
}
-if len(columns) == 0 {
-return columns, log.Errorf("Found 0 columns on %s.%s. Bailing out",
+if len(columnNames) == 0 {
+return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out",
sql.EscapeName(databaseName),
sql.EscapeName(tableName),
)
}
-return columns, nil
+return sql.NewColumnList(columnNames), nil
}
// getCandidateUniqueKeys investigates a table and returns the list of unique keys // getCandidateUniqueKeys investigates a table and returns the list of unique keys
@ -361,17 +430,9 @@ func (this *Inspector) getCandidateUniqueKeys(tableName string) (uniqueKeys [](*
return uniqueKeys, nil return uniqueKeys, nil
} }
-// getCandidateUniqueKeys investigates a table and returns the list of unique keys
-// candidate for chunking
-func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err error) {
+// getSharedUniqueKeys returns the intersection of two given unique keys,
+// testing by list of columns
+func (this *Inspector) getSharedUniqueKeys(originalUniqueKeys, ghostUniqueKeys [](*sql.UniqueKey)) (uniqueKeys [](*sql.UniqueKey), err error) {
originalUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName)
if err != nil {
return uniqueKeys, err
}
ghostUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.GetGhostTableName())
if err != nil {
return uniqueKeys, err
}
// We actually do NOT rely on key name, just on the set of columns. This is because maybe // We actually do NOT rely on key name, just on the set of columns. This is because maybe
// the ALTER is on the name itself... // the ALTER is on the name itself...
for _, originalUniqueKey := range originalUniqueKeys { for _, originalUniqueKey := range originalUniqueKeys {
@ -384,43 +445,77 @@ func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err
return uniqueKeys, nil return uniqueKeys, nil
} }
-func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) {
-visitedKeys := mysql.NewInstanceKeyMap()
-return getMasterConnectionConfigSafe(this.connectionConfig, this.migrationContext.DatabaseName, visitedKeys)
+// getSharedColumns returns the intersection of two lists of columns in same order as the first list
+func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.ColumnList) *sql.ColumnList {
+columnsInGhost := make(map[string]bool)
for _, ghostColumn := range ghostColumns.Names {
columnsInGhost[ghostColumn] = true
}
sharedColumnNames := []string{}
for _, originalColumn := range originalColumns.Names {
if columnsInGhost[originalColumn] {
sharedColumnNames = append(sharedColumnNames, originalColumn)
}
}
return sql.NewColumnList(sharedColumnNames)
} }
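A small worked example of the column intersection above, with hypothetical column lists (the ALTER drops column c and adds column e):

```go
original := sql.NewColumnList([]string{"id", "a", "b", "c"})
ghost := sql.NewColumnList([]string{"id", "a", "b", "e"})
shared := this.getSharedColumns(original, ghost)
// shared preserves the ordering of the original list: ["id", "a", "b"];
// only these columns are copied by the row-copy INSERT and by the DML event applier.
```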
-func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, databaseName string, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) {
-log.Debugf("Looking for master on %+v", connectionConfig.Key)
+func (this *Inspector) readChangelogState() (map[string]string, error) {
+query := fmt.Sprintf(`
select hint, value from %s.%s where id <= 255
`,
sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()),
)
result := make(map[string]string)
err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error {
result[m.GetString("hint")] = m.GetString("value")
return nil
})
return result, err
}
-currentUri := connectionConfig.GetDBUri(databaseName)
+func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) {
visitedKeys := mysql.NewInstanceKeyMap()
return getMasterConnectionConfigSafe(this.connectionConfig, visitedKeys)
}
func getMasterKeyFromSlaveStatus(connectionConfig *mysql.ConnectionConfig) (masterKey *mysql.InstanceKey, err error) {
currentUri := connectionConfig.GetDBUri("information_schema")
db, _, err := sqlutils.GetDB(currentUri) db, _, err := sqlutils.GetDB(currentUri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hasMaster := false
masterConfig = connectionConfig.Duplicate()
err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error {
-masterKey := mysql.InstanceKey{
+masterKey = &mysql.InstanceKey{
Hostname: rowMap.GetString("Master_Host"), Hostname: rowMap.GetString("Master_Host"),
Port: rowMap.GetInt("Master_Port"), Port: rowMap.GetInt("Master_Port"),
} }
if masterKey.IsValid() {
masterConfig.Key = masterKey
hasMaster = true
}
return nil return nil
}) })
return masterKey, err
}
func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) {
log.Debugf("Looking for master on %+v", connectionConfig.Key)
masterKey, err := getMasterKeyFromSlaveStatus(connectionConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
-if hasMaster {
+if masterKey == nil {
return connectionConfig, nil
}
if !masterKey.IsValid() {
return connectionConfig, nil
}
masterConfig = connectionConfig.Duplicate()
masterConfig.Key = *masterKey
log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key)
if visitedKeys.HasKey(masterConfig.Key) { if visitedKeys.HasKey(masterConfig.Key) {
return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key) return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key)
} }
visitedKeys.AddKey(masterConfig.Key) visitedKeys.AddKey(masterConfig.Key)
-return getMasterConnectionConfigSafe(masterConfig, databaseName, visitedKeys)
+return getMasterConnectionConfigSafe(masterConfig, visitedKeys)
}
return masterConfig, nil
} }
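The functions above walk the replication topology: getMasterKeyFromSlaveStatus reads SHOW SLAVE STATUS on one server, and getMasterConnectionConfigSafe recurses hop by hop until it reaches a server that has no master, with visitedKeys guarding against a co-master (master-master) loop. A sketch of the intended call, with a hypothetical starting replica:

```go
// Starting from the inspected replica, climb until SHOW SLAVE STATUS returns no master row.
masterConfig, err := getMasterConnectionConfigSafe(replicaConnectionConfig, mysql.NewInstanceKeyMap())
if err != nil {
	// e.g. "There seems to be a master-master setup at ... This is unsupported. Bailing out"
}
```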

View File

@ -8,7 +8,10 @@ package logic
import ( import (
"fmt" "fmt"
"os" "os"
"os/signal"
"regexp"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"github.com/github/gh-osc/go/base" "github.com/github/gh-osc/go/base"
@ -28,6 +31,11 @@ type tableWriteFunc func() error
const ( const (
applyEventsQueueBuffer = 100 applyEventsQueueBuffer = 100
heartbeatIntervalMilliseconds = 1000
)
var (
prettifyDurationRegexp = regexp.MustCompile("([.][0-9]+)")
) )
// Migrator is the main schema migration flow manager. // Migrator is the main schema migration flow manager.
@ -45,6 +53,8 @@ type Migrator struct {
// excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete
copyRowsQueue chan tableWriteFunc copyRowsQueue chan tableWriteFunc
applyEventsQueue chan tableWriteFunc applyEventsQueue chan tableWriteFunc
handledChangelogStates map[string]bool
} }
func NewMigrator() *Migrator { func NewMigrator() *Migrator {
@ -56,38 +66,136 @@ func NewMigrator() *Migrator {
copyRowsQueue: make(chan tableWriteFunc), copyRowsQueue: make(chan tableWriteFunc),
applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer),
+handledChangelogStates: make(map[string]bool),
}
migrator.migrationContext.IsThrottled = func() bool {
return migrator.shouldThrottle()
} }
return migrator return migrator
} }
-func (this *Migrator) shouldThrottle() bool {
+func prettifyDurationOutput(d time.Duration) string {
if d < time.Second {
return "0s"
}
result := fmt.Sprintf("%s", d)
result = prettifyDurationRegexp.ReplaceAllString(result, "")
return result
}
// acceptSignals registers for OS signals
func (this *Migrator) acceptSignals() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGHUP)
go func() {
for sig := range c {
switch sig {
case syscall.SIGHUP:
log.Debugf("Received SIGHUP. Reloading configuration")
}
}
}()
}
+func (this *Migrator) shouldThrottle() (result bool, reason string) {
lag := atomic.LoadInt64(&this.migrationContext.CurrentLag)
-shouldThrottle := false
if time.Duration(lag) > time.Duration(this.migrationContext.MaxLagMillisecondsThrottleThreshold)*time.Millisecond {
-shouldThrottle = true
+return true, fmt.Sprintf("lag=%fs", time.Duration(lag).Seconds())
-} else if this.migrationContext.ThrottleFlagFile != "" {
+}
+if this.migrationContext.ThrottleFlagFile != "" {
if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil {
// Throttle file defined and exists!
-shouldThrottle = true
+return true, "flag-file"
}
}
-return shouldThrottle
+if this.migrationContext.ThrottleAdditionalFlagFile != "" {
if _, err := os.Stat(this.migrationContext.ThrottleAdditionalFlagFile); err == nil {
// 2nd Throttle file defined and exists!
return true, "flag-file"
}
}
for variableName, threshold := range this.migrationContext.MaxLoad {
value, err := this.applier.ShowStatusVariable(variableName)
if err != nil {
return true, fmt.Sprintf("%s %s", variableName, err)
}
if value > threshold {
return true, fmt.Sprintf("%s=%d", variableName, value)
}
}
return false, ""
}
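Two small notes on the helpers in this hunk: prettifyDurationOutput strips the fractional seconds that Go's default Duration formatting produces, and shouldThrottle now returns a reason string alongside the boolean. Illustrative values (inputs assumed):

```go
prettifyDurationOutput(45 * time.Second)         // "45s"
prettifyDurationOutput(83500 * time.Millisecond) // "1m23.5s" -> "1m23s" after the regexp strips ".5"
prettifyDurationOutput(300 * time.Millisecond)   // "0s" (anything under one second)

// shouldThrottle() example (result, reason) pairs:
//   (true, "lag=6.282472s"), (true, "flag-file"), (true, "Threads_running=212"), (false, "")
```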
func (this *Migrator) initiateThrottler() error {
throttlerTick := time.Tick(1 * time.Second)
throttlerFunction := func() {
alreadyThrottling, currentReason := this.migrationContext.IsThrottled()
shouldThrottle, throttleReason := this.shouldThrottle()
if shouldThrottle && !alreadyThrottling {
// New throttling
this.applier.WriteAndLogChangelog("throttle", throttleReason)
} else if shouldThrottle && alreadyThrottling && (currentReason != throttleReason) {
// Change of reason
this.applier.WriteAndLogChangelog("throttle", throttleReason)
} else if alreadyThrottling && !shouldThrottle {
// End of throttling
this.applier.WriteAndLogChangelog("throttle", "done throttling")
}
this.migrationContext.SetThrottled(shouldThrottle, throttleReason)
}
throttlerFunction()
for range throttlerTick {
throttlerFunction()
}
return nil
}
// throttle initiates a throttling event, if need be, updates the Context and
// calls callback functions, if any
func (this *Migrator) throttle(onThrottled func()) {
for {
if shouldThrottle, _ := this.migrationContext.IsThrottled(); !shouldThrottle {
return
}
if onThrottled != nil {
onThrottled()
}
time.Sleep(time.Second)
}
}
// retryOperation attempts up to `count` attempts at running given function,
// exiting as soon as it returns with non-error.
func (this *Migrator) retryOperation(operation func() error) (err error) {
maxRetries := this.migrationContext.MaxRetries()
for i := 0; i < maxRetries; i++ {
if i != 0 {
// sleep after previous iteration
time.Sleep(1 * time.Second)
}
err = operation()
if err == nil {
return nil
}
// there's an error. Let's try again.
}
return err
} }
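retryOperation is the generic retry wrapper that the write loop below applies to both queued task types; a minimal usage sketch:

```go
copyRowsFunc := func() error {
	_, _, _, err := this.applier.ApplyIterationInsertQuery()
	return err
}
// up to MaxRetries() attempts, sleeping one second between attempts:
if err := this.retryOperation(copyRowsFunc); err != nil {
	return log.Errore(err)
}
```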
func (this *Migrator) canStopStreaming() bool { func (this *Migrator) canStopStreaming() bool {
return false return false
} }
-func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) {
-// Hey, I created the changlog table, I know the type of columns it has!
-if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" {
+func (this *Migrator) onChangelogState(stateValue string) (err error) {
+if this.handledChangelogStates[stateValue] {
return
}
-changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3))
+this.handledChangelogStates[stateValue] = true
+changelogState := ChangelogState(stateValue)
switch changelogState { switch changelogState {
case TablesInPlace: case TablesInPlace:
{ {
@ -102,16 +210,12 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er
return fmt.Errorf("Unknown changelog state: %+v", changelogState) return fmt.Errorf("Unknown changelog state: %+v", changelogState)
} }
} }
log.Debugf("---- - - - - - state %+v", changelogState) log.Debugf("Received state %+v", changelogState)
return nil return nil
} }
-func (this *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) {
-if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "heartbeat" {
-return nil
-}
-value := dmlEvent.NewColumnValues.StringColumn(3)
-heartbeatTime, err := time.Parse(time.RFC3339, value)
+func (this *Migrator) onChangelogHeartbeat(heartbeatValue string) (err error) {
+heartbeatTime, err := time.Parse(time.RFC3339, heartbeatValue)
if err != nil { if err != nil {
return log.Errore(err) return log.Errore(err)
} }
@ -132,18 +236,29 @@ func (this *Migrator) Migrate() (err error) {
if err := this.inspector.ValidateOriginalTable(); err != nil { if err := this.inspector.ValidateOriginalTable(); err != nil {
return err return err
} }
-uniqueKeys, err := this.inspector.InspectOriginalTable()
-if err != nil {
+if err := this.inspector.InspectOriginalTable(); err != nil {
return err return err
} }
// So far so good, table is accessible and valid.
+// Let's get master connection config
-if this.migrationContext.MasterConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil {
+if this.migrationContext.ApplierConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil {
return err return err
} }
-if this.migrationContext.IsRunningOnMaster() && !this.migrationContext.AllowedRunningOnMaster {
+if this.migrationContext.TestOnReplica {
if this.migrationContext.InspectorIsAlsoApplier() {
return fmt.Errorf("Instructed to --test-on-replica, but the server we connect to doesn't seem to be a replica")
}
log.Infof("--test-on-replica given. Will not execute on master %+v but rather on replica %+v itself",
this.migrationContext.ApplierConnectionConfig.Key, this.migrationContext.InspectorConnectionConfig.Key,
)
this.migrationContext.ApplierConnectionConfig = this.migrationContext.InspectorConnectionConfig.Duplicate()
} else if this.migrationContext.InspectorIsAlsoApplier() && !this.migrationContext.AllowedRunningOnMaster {
return fmt.Errorf("It seems like this migration attempt to run directly on master. Preferably it would be executed on a replica (and this reduces load from the master). To proceed please provide --allow-on-master") return fmt.Errorf("It seems like this migration attempt to run directly on master. Preferably it would be executed on a replica (and this reduces load from the master). To proceed please provide --allow-on-master")
} }
log.Infof("Master found to be %+v", this.migrationContext.MasterConnectionConfig.Key)
log.Infof("Master found to be %+v", this.migrationContext.ApplierConnectionConfig.Key)
go this.initiateChangelogListener()
if err := this.initiateStreaming(); err != nil { if err := this.initiateStreaming(); err != nil {
return err return err
@ -159,33 +274,30 @@ func (this *Migrator) Migrate() (err error) {
// When running on replica, this means the replica has those tables. When running // When running on replica, this means the replica has those tables. When running
// on master this is always true, of course, and yet it also implies this knowledge // on master this is always true, of course, and yet it also implies this knowledge
// is in the binlogs. // is in the binlogs.
if err := this.inspector.InspectOriginalAndGhostTables(); err != nil {
return err
}
this.migrationContext.UniqueKey = uniqueKeys[0] // TODO. Need to wait on replica till the ghost table exists and get shared keys
if err := this.applier.ReadMigrationRangeValues(); err != nil { if err := this.applier.ReadMigrationRangeValues(); err != nil {
return err return err
} }
-go this.initiateStatus()
+go this.initiateThrottler()
go this.executeWriteFuncs() go this.executeWriteFuncs()
go this.iterateChunks() go this.iterateChunks()
this.migrationContext.RowCopyStartTime = time.Now()
go this.initiateStatus()
log.Debugf("Operating until row copy is complete") log.Debugf("Operating until row copy is complete")
<-this.rowCopyComplete <-this.rowCopyComplete
log.Debugf("Row copy complete") log.Debugf("Row copy complete")
this.printStatus() this.printStatus()
-throttleMigration(
-this.migrationContext,
-func() {
+this.throttle(func() {
+log.Debugf("throttling on LOCK TABLES")
+})
log.Debugf("throttling before LOCK TABLES")
},
nil,
func() {
log.Debugf("done throttling")
},
)
// TODO retries!! // TODO retries!!
this.applier.LockTables() this.applier.LockTables()
this.applier.WriteChangelog("state", string(AllEventsUpToLockProcessed)) this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed))
log.Debugf("Waiting for events up to lock") log.Debugf("Waiting for events up to lock")
<-this.allEventsUpToLockProcessed <-this.allEventsUpToLockProcessed
log.Debugf("Done waiting for events up to lock") log.Debugf("Done waiting for events up to lock")
@@ -228,34 +340,71 @@ func (this *Migrator) printStatus() {
return
}
-status := fmt.Sprintf("Copy: %d/%d %.1f%% Backlog: %d/%d Elapsed: %+v(copy), %+v(total) ETA: N/A",
+eta := "N/A"
+if isThrottled, throttleReason := this.migrationContext.IsThrottled(); isThrottled {
+eta = fmt.Sprintf("throttled, %s", throttleReason)
+}
+status := fmt.Sprintf("Copy: %d/%d %.1f%%; Backlog: %d/%d; Elapsed: %+v(copy), %+v(total); ETA: %s",
totalRowsCopied, rowsEstimate, progressPct,
len(this.applyEventsQueue), cap(this.applyEventsQueue),
-this.migrationContext.ElapsedRowCopyTime(), elapsedTime)
+prettifyDurationOutput(this.migrationContext.ElapsedRowCopyTime()), prettifyDurationOutput(elapsedTime),
+eta,
+)
+this.applier.WriteChangelog(
+fmt.Sprintf("copy iteration %d at %d", this.migrationContext.GetIteration(), time.Now().Unix()),
+status,
+)
fmt.Println(status)
}
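An illustrative status line produced by the format above (numbers made up); while throttled the ETA field carries the reason instead of "N/A":

```
Copy: 152000/1403921 10.8%; Backlog: 0/100; Elapsed: 2m30s(copy), 2m48s(total); ETA: N/A
Copy: 152000/1403921 10.8%; Backlog: 7/100; Elapsed: 2m30s(copy), 2m48s(total); ETA: throttled, lag=6.282472s
```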
func (this *Migrator) initiateChangelogListener() {
ticker := time.Tick((heartbeatIntervalMilliseconds * time.Millisecond) / 2)
for range ticker {
go func() error {
changelogState, err := this.inspector.readChangelogState()
if err != nil {
return log.Errore(err)
}
for hint, value := range changelogState {
switch hint {
case "state":
{
this.onChangelogState(value)
}
case "heartbeat":
{
this.onChangelogHeartbeat(value)
}
}
}
return nil
}()
}
}
// initiateStreaming begins treaming of binary log events and registers listeners for such events
func (this *Migrator) initiateStreaming() error { func (this *Migrator) initiateStreaming() error {
this.eventsStreamer = NewEventsStreamer() this.eventsStreamer = NewEventsStreamer()
if err := this.eventsStreamer.InitDBConnections(); err != nil { if err := this.eventsStreamer.InitDBConnections(); err != nil {
return err return err
} }
if this.migrationContext.Noop {
log.Debugf("Noop operation; not really listening on binlog events")
return nil
}
this.eventsStreamer.AddListener( this.eventsStreamer.AddListener(
-false,
+true,
this.migrationContext.DatabaseName, this.migrationContext.DatabaseName,
-this.migrationContext.GetChangelogTableName(),
+this.migrationContext.OriginalTableName,
func(dmlEvent *binlog.BinlogDMLEvent) error { func(dmlEvent *binlog.BinlogDMLEvent) error {
-return this.onChangelogStateEvent(dmlEvent)
-},
-)
-this.eventsStreamer.AddListener(
-false,
+applyEventFunc := func() error {
+return this.applier.ApplyDMLEventQuery(dmlEvent)
+}
+this.applyEventsQueue <- applyEventFunc
+return nil
this.migrationContext.DatabaseName,
this.migrationContext.GetChangelogTableName(),
func(dmlEvent *binlog.BinlogDMLEvent) error {
return this.onChangelogHeartbeatEvent(dmlEvent)
}, },
) )
go func() { go func() {
log.Debugf("Beginning streaming") log.Debugf("Beginning streaming")
this.eventsStreamer.StreamEvents(func() bool { return this.canStopStreaming() }) this.eventsStreamer.StreamEvents(func() bool { return this.canStopStreaming() })
@ -281,17 +430,24 @@ func (this *Migrator) initiateApplier() error {
return err return err
} }
this.applier.WriteChangelog("state", string(TablesInPlace)) this.applier.WriteChangelogState(string(TablesInPlace))
this.applier.InitiateHeartbeat() go this.applier.InitiateHeartbeat(heartbeatIntervalMilliseconds)
return nil return nil
} }
func (this *Migrator) iterateChunks() error { func (this *Migrator) iterateChunks() error {
this.migrationContext.RowCopyStartTime = time.Now()
terminateRowIteration := func(err error) error { terminateRowIteration := func(err error) error {
this.rowCopyComplete <- true this.rowCopyComplete <- true
return log.Errore(err) return log.Errore(err)
} }
if this.migrationContext.Noop {
log.Debugf("Noop operation; not really copying data")
return terminateRowIteration(nil)
}
if this.migrationContext.MigrationRangeMinValues == nil {
log.Debugf("No rows found in table. Rowcopy will be implicitly empty")
return terminateRowIteration(nil)
}
for { for {
copyRowsFunc := func() error { copyRowsFunc := func() error {
hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues() hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues()
@ -306,7 +462,7 @@ func (this *Migrator) iterateChunks() error {
return terminateRowIteration(err) return terminateRowIteration(err)
} }
atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected) atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected)
this.migrationContext.Iteration++ atomic.AddInt64(&this.migrationContext.Iteration, 1)
return nil return nil
} }
this.copyRowsQueue <- copyRowsFunc this.copyRowsQueue <- copyRowsFunc
@ -315,30 +471,29 @@ func (this *Migrator) iterateChunks() error {
} }
func (this *Migrator) executeWriteFuncs() error { func (this *Migrator) executeWriteFuncs() error {
if this.migrationContext.Noop {
log.Debugf("Noop operation; not really doing writes")
return nil
}
for { for {
throttleMigration( this.throttle(nil)
this.migrationContext,
func() {
log.Debugf("throttling writes")
},
nil,
func() {
log.Debugf("done throttling writes")
},
)
// We give higher priority to event processing, then secondary priority to // We give higher priority to event processing, then secondary priority to
// rowcopy // rowcopy
select { select {
case applyEventFunc := <-this.applyEventsQueue: case applyEventFunc := <-this.applyEventsQueue:
{ {
retryOperation(applyEventFunc, this.migrationContext.MaxRetries()) if err := this.retryOperation(applyEventFunc); err != nil {
return log.Errore(err)
}
} }
default: default:
{ {
select { select {
case copyRowsFunc := <-this.copyRowsQueue: case copyRowsFunc := <-this.copyRowsQueue:
{ {
retryOperation(copyRowsFunc, this.migrationContext.MaxRetries()) if err := this.retryOperation(copyRowsFunc); err != nil {
return log.Errore(err)
}
} }
default: default:
{ {

View File

@ -32,6 +32,20 @@ func EscapeName(name string) string {
return fmt.Sprintf("`%s`", name) return fmt.Sprintf("`%s`", name)
} }
func buildPreparedValues(length int) []string {
values := make([]string, length, length)
for i := 0; i < length; i++ {
values[i] = "?"
}
return values
}
func duplicateNames(names []string) []string {
duplicate := make([]string, len(names), len(names))
copy(duplicate, names)
return duplicate
}
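The copy helper above matters because the builders below escape column names in place; a short sketch of the intended pattern (names here are hypothetical, mirroring its use in BuildRangeInsertQuery and the min/max query builder):

names := []string{"id", "name"} // hypothetical caller-owned slice
escaped := duplicateNames(names)
for i := range escaped {
escaped[i] = EscapeName(escaped[i]) // escaped == []string{"`id`", "`name`"}; names stays untouched
}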
func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) { func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) {
if column == "" { if column == "" {
return "", fmt.Errorf("Empty column in GetValueComparison") return "", fmt.Errorf("Empty column in GetValueComparison")
@ -64,6 +78,22 @@ func BuildEqualsComparison(columns []string, values []string) (result string, er
return result, nil return result, nil
} }
func BuildEqualsPreparedComparison(columns []string) (result string, err error) {
values := buildPreparedValues(len(columns))
return BuildEqualsComparison(columns, values)
}
func BuildSetPreparedClause(columns []string) (result string, err error) {
if len(columns) == 0 {
return "", fmt.Errorf("Got 0 columns in BuildSetPreparedClause")
}
setTokens := []string{}
for _, column := range columns {
setTokens = append(setTokens, fmt.Sprintf("%s=?", EscapeName(column)))
}
return strings.Join(setTokens, ", "), nil
}
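For reference, the two prepared-statement helpers produce output like the following (matching the unit tests further down in this diff):

// BuildEqualsPreparedComparison([]string{"c1", "c2"}) => ((`c1` = ?) and (`c2` = ?))
// BuildSetPreparedClause([]string{"c1", "c2"})        => `c1`=?, `c2`=?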
func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
if len(columns) == 0 { if len(columns) == 0 {
return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison")
@ -121,10 +151,7 @@ func BuildRangeComparison(columns []string, values []string, args []interface{},
} }
func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
values := make([]string, len(columns), len(columns)) values := buildPreparedValues(len(columns))
for i := range columns {
values[i] = "?"
}
return BuildRangeComparison(columns, values, args, comparisonSign) return BuildRangeComparison(columns, values, args, comparisonSign)
} }
@ -135,6 +162,7 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin
databaseName = EscapeName(databaseName) databaseName = EscapeName(databaseName)
originalTableName = EscapeName(originalTableName) originalTableName = EscapeName(originalTableName)
ghostTableName = EscapeName(ghostTableName) ghostTableName = EscapeName(ghostTableName)
sharedColumns = duplicateNames(sharedColumns)
for i := range sharedColumns { for i := range sharedColumns {
sharedColumns[i] = EscapeName(sharedColumns[i]) sharedColumns[i] = EscapeName(sharedColumns[i])
} }
@ -171,12 +199,8 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin
} }
func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) {
rangeStartValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) rangeStartValues := buildPreparedValues(len(uniqueKeyColumns))
rangeEndValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) rangeEndValues := buildPreparedValues(len(uniqueKeyColumns))
for i := range uniqueKeyColumns {
rangeStartValues[i] = "?"
rangeEndValues[i] = "?"
}
return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable)
} }
@ -198,6 +222,7 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK
} }
explodedArgs = append(explodedArgs, rangeExplodedArgs...) explodedArgs = append(explodedArgs, rangeExplodedArgs...)
uniqueKeyColumns = duplicateNames(uniqueKeyColumns)
uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
for i := range uniqueKeyColumns { for i := range uniqueKeyColumns {
@ -244,6 +269,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni
databaseName = EscapeName(databaseName) databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName) tableName = EscapeName(tableName)
uniqueKeyColumns = duplicateNames(uniqueKeyColumns)
uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
for i := range uniqueKeyColumns { for i := range uniqueKeyColumns {
uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i])
@ -262,3 +288,125 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni
) )
return query, nil return query, nil
} }
func BuildDMLDeleteQuery(databaseName, tableName string, tableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) {
if len(args) != tableColumns.Len() {
return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery")
}
if uniqueKeyColumns.Len() == 0 {
return result, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLDeleteQuery")
}
for _, column := range uniqueKeyColumns.Names {
tableOrdinal := tableColumns.Ordinals[column]
uniqueKeyArgs = append(uniqueKeyArgs, args[tableOrdinal])
}
databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName)
equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names)
if err != nil {
return result, uniqueKeyArgs, err
}
result = fmt.Sprintf(`
delete /* gh-osc %s.%s */
from
%s.%s
where
%s
`, databaseName, tableName,
databaseName, tableName,
equalsComparison,
)
return result, uniqueKeyArgs, nil
}
func BuildDMLInsertQuery(databaseName, tableName string, tableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) {
if len(args) != tableColumns.Len() {
return result, args, fmt.Errorf("args count differs from table column count in BuildDMLInsertQuery")
}
if !sharedColumns.IsSubsetOf(tableColumns) {
return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery")
}
if sharedColumns.Len() == 0 {
return result, args, fmt.Errorf("No shared columns found in BuildDMLInsertQuery")
}
databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName)
for _, column := range sharedColumns.Names {
tableOrdinal := tableColumns.Ordinals[column]
sharedArgs = append(sharedArgs, args[tableOrdinal])
}
sharedColumnNames := duplicateNames(sharedColumns.Names)
for i := range sharedColumnNames {
sharedColumnNames[i] = EscapeName(sharedColumnNames[i])
}
preparedValues := buildPreparedValues(sharedColumns.Len())
result = fmt.Sprintf(`
replace /* gh-osc %s.%s */ into
%s.%s
(%s)
values
(%s)
`, databaseName, tableName,
databaseName, tableName,
strings.Join(sharedColumnNames, ", "),
strings.Join(preparedValues, ", "),
)
return result, sharedArgs, nil
}
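Note that the generated statement is a REPLACE rather than a plain INSERT, presumably so that re-applying the same row event onto the ghost table overwrites the existing row instead of failing on a duplicate unique key.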
func BuildDMLUpdateQuery(databaseName, tableName string, tableColumns, sharedColumns, uniqueKeyColumns *ColumnList, valueArgs, whereArgs []interface{}) (result string, sharedArgs, uniqueKeyArgs []interface{}, err error) {
if len(valueArgs) != tableColumns.Len() {
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("value args count differs from table column count in BuildDMLUpdateQuery")
}
if len(whereArgs) != tableColumns.Len() {
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("where args count differs from table column count in BuildDMLUpdateQuery")
}
if !sharedColumns.IsSubsetOf(tableColumns) {
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLUpdateQuery")
}
if !uniqueKeyColumns.IsSubsetOf(sharedColumns) {
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("unique key columns is not a subset of shared columns in BuildDMLUpdateQuery")
}
if sharedColumns.Len() == 0 {
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No shared columns found in BuildDMLUpdateQuery")
}
if uniqueKeyColumns.Len() == 0 {
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLUpdateQuery")
}
databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName)
for _, column := range sharedColumns.Names {
tableOrdinal := tableColumns.Ordinals[column]
sharedArgs = append(sharedArgs, valueArgs[tableOrdinal])
}
for _, column := range uniqueKeyColumns.Names {
tableOrdinal := tableColumns.Ordinals[column]
uniqueKeyArgs = append(uniqueKeyArgs, whereArgs[tableOrdinal])
}
sharedColumnNames := duplicateNames(sharedColumns.Names)
for i := range sharedColumnNames {
sharedColumnNames[i] = EscapeName(sharedColumnNames[i])
}
setClause, err := BuildSetPreparedClause(sharedColumnNames)
if err != nil {
return result, sharedArgs, uniqueKeyArgs, err
}
equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names)
if err != nil {
return result, sharedArgs, uniqueKeyArgs, err
}
result = fmt.Sprintf(`
update /* gh-osc %s.%s */
%s.%s
set
%s
where
%s
`, databaseName, tableName,
databaseName, tableName,
setClause,
equalsComparison,
)
return result, sharedArgs, uniqueKeyArgs, nil
}
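Because the SET placeholders come before the WHERE placeholders, a caller binds sharedArgs ahead of uniqueKeyArgs. A minimal, hypothetical sketch (db being some *sql.DB handle; none of these names come from this diff):

query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery("mydb", "tbl", tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
if err != nil {
return err
}
// placeholder order: set clause args first, then where clause args
_, err = db.Exec(query, append(sharedArgs, uniqueKeyArgs...)...)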

View File

@ -68,6 +68,35 @@ func TestBuildEqualsComparison(t *testing.T) {
} }
} }
func TestBuildEqualsPreparedComparison(t *testing.T) {
{
columns := []string{"c1", "c2"}
comparison, err := BuildEqualsPreparedComparison(columns)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(comparison, "((`c1` = ?) and (`c2` = ?))")
}
}
func TestBuildSetPreparedClause(t *testing.T) {
{
columns := []string{"c1"}
clause, err := BuildSetPreparedClause(columns)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(clause, "`c1`=?")
}
{
columns := []string{"c1", "c2"}
clause, err := BuildSetPreparedClause(columns)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(clause, "`c1`=?, `c2`=?")
}
{
columns := []string{}
_, err := BuildSetPreparedClause(columns)
test.S(t).ExpectNotNil(err)
}
}
func TestBuildRangeComparison(t *testing.T) { func TestBuildRangeComparison(t *testing.T) {
{ {
columns := []string{"c1"} columns := []string{"c1"}
@ -143,7 +172,7 @@ func TestBuildRangeInsertQuery(t *testing.T) {
rangeStartArgs := []interface{}{3} rangeStartArgs := []interface{}{3}
rangeEndArgs := []interface{}{103} rangeEndArgs := []interface{}{103}
query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false)
test.S(t).ExpectNil(err) test.S(t).ExpectNil(err)
expected := ` expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
@ -162,7 +191,7 @@ func TestBuildRangeInsertQuery(t *testing.T) {
rangeStartArgs := []interface{}{3, 17} rangeStartArgs := []interface{}{3, 17}
rangeEndArgs := []interface{}{103, 117} rangeEndArgs := []interface{}{103, 117}
query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false)
test.S(t).ExpectNil(err) test.S(t).ExpectNil(err)
expected := ` expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
@ -186,13 +215,13 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) {
rangeStartArgs := []interface{}{3, 17} rangeStartArgs := []interface{}{3, 17}
rangeEndArgs := []interface{}{103, 117} rangeEndArgs := []interface{}{103, 117}
query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true) query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true, true)
test.S(t).ExpectNil(err) test.S(t).ExpectNil(err)
expected := ` expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
(select id, name, position from mydb.tbl force index (name_position_uidx) (select id, name, position from mydb.tbl force index (name_position_uidx)
where (((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?))) and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) and (position = ?)))) where (((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?))) and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) and (position = ?))))
) lock in share mode )
` `
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117})) test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117}))
@ -202,7 +231,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) {
func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) {
databaseName := "mydb" databaseName := "mydb"
originalTableName := "tbl" originalTableName := "tbl"
chunkSize := 500 var chunkSize int64 = 500
{ {
uniqueKeyColumns := []string{"name", "position"} uniqueKeyColumns := []string{"name", "position"}
rangeStartArgs := []interface{}{3, 17} rangeStartArgs := []interface{}{3, 17}
@ -262,3 +291,191 @@ func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) {
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
} }
} }
func TestBuildDMLDeleteQuery(t *testing.T) {
databaseName := "mydb"
tableName := "tbl"
tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"})
args := []interface{}{3, "testname", "first", 17, 23}
{
uniqueKeyColumns := NewColumnList([]string{"position"})
query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNil(err)
expected := `
delete /* gh-osc mydb.tbl */
from
mydb.tbl
where
((position = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17}))
}
{
uniqueKeyColumns := NewColumnList([]string{"name", "position"})
query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNil(err)
expected := `
delete /* gh-osc mydb.tbl */
from
mydb.tbl
where
((name = ?) and (position = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{"testname", 17}))
}
{
uniqueKeyColumns := NewColumnList([]string{"position", "name"})
query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNil(err)
expected := `
delete /* gh-osc mydb.tbl */
from
mydb.tbl
where
((position = ?) and (name = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"}))
}
{
uniqueKeyColumns := NewColumnList([]string{"position", "name"})
args := []interface{}{"first", 17}
_, _, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNotNil(err)
}
}
func TestBuildDMLInsertQuery(t *testing.T) {
databaseName := "mydb"
tableName := "tbl"
tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"})
args := []interface{}{3, "testname", "first", 17, 23}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNil(err)
expected := `
replace /* gh-osc mydb.tbl */
into mydb.tbl
(id, name, position, age)
values
(?, ?, ?, ?)
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23}))
}
{
sharedColumns := NewColumnList([]string{"position", "name", "age", "id"})
query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNil(err)
expected := `
replace /* gh-osc mydb.tbl */
into mydb.tbl
(position, name, age, id)
values
(?, ?, ?, ?)
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{17, "testname", 23, 3}))
}
{
sharedColumns := NewColumnList([]string{"position", "name", "surprise", "id"})
_, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNotNil(err)
}
{
sharedColumns := NewColumnList([]string{})
_, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNotNil(err)
}
}
func TestBuildDMLUpdateQuery(t *testing.T) {
databaseName := "mydb"
tableName := "tbl"
tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"})
valueArgs := []interface{}{3, "testname", "newval", 17, 23}
whereArgs := []interface{}{3, "testname", "findme", 17, 56}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
uniqueKeyColumns := NewColumnList([]string{"position"})
query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
test.S(t).ExpectNil(err)
expected := `
update /* gh-osc mydb.tbl */
mydb.tbl
set id=?, name=?, position=?, age=?
where
((position = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23}))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17}))
}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
uniqueKeyColumns := NewColumnList([]string{"position", "name"})
query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
test.S(t).ExpectNil(err)
expected := `
update /* gh-osc mydb.tbl */
mydb.tbl
set id=?, name=?, position=?, age=?
where
((position = ?) and (name = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23}))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"}))
}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
uniqueKeyColumns := NewColumnList([]string{"age"})
query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
test.S(t).ExpectNil(err)
expected := `
update /* gh-osc mydb.tbl */
mydb.tbl
set id=?, name=?, position=?, age=?
where
((age = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23}))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56}))
}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
uniqueKeyColumns := NewColumnList([]string{"age", "position", "id", "name"})
query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
test.S(t).ExpectNil(err)
expected := `
update /* gh-osc mydb.tbl */
mydb.tbl
set id=?, name=?, position=?, age=?
where
((age = ?) and (position = ?) and (id = ?) and (name = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23}))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56, 17, 3, "testname"}))
}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
uniqueKeyColumns := NewColumnList([]string{"age", "surprise"})
_, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
test.S(t).ExpectNotNil(err)
}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
uniqueKeyColumns := NewColumnList([]string{})
_, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
test.S(t).ExpectNotNil(err)
}
}

View File

@ -11,21 +11,62 @@ import (
"strings" "strings"
) )
// ColumnsMap maps a column onto its ordinal position
type ColumnsMap map[string]int
func NewColumnsMap(orderedNames []string) ColumnsMap {
columnsMap := make(map[string]int)
for i, column := range orderedNames {
columnsMap[column] = i
}
return ColumnsMap(columnsMap)
}
// ColumnList makes for a named list of columns // ColumnList makes for a named list of columns
type ColumnList []string type ColumnList struct {
Names []string
Ordinals ColumnsMap
}
// NewColumnList creates an object given an ordered list of column names
func NewColumnList(names []string) *ColumnList {
result := &ColumnList{
Names: names,
}
result.Ordinals = NewColumnsMap(result.Names)
return result
}
// ParseColumnList parses a comma delimited list of column names // ParseColumnList parses a comma delimited list of column names
func ParseColumnList(columns string) *ColumnList { func ParseColumnList(columns string) *ColumnList {
result := ColumnList(strings.Split(columns, ",")) result := &ColumnList{
return &result Names: strings.Split(columns, ","),
}
result.Ordinals = NewColumnsMap(result.Names)
return result
} }
func (this *ColumnList) String() string { func (this *ColumnList) String() string {
return strings.Join(*this, ",") return strings.Join(this.Names, ",")
} }
func (this *ColumnList) Equals(other *ColumnList) bool { func (this *ColumnList) Equals(other *ColumnList) bool {
return reflect.DeepEqual(*this, *other) return reflect.DeepEqual(this.Names, other.Names)
}
// IsSubsetOf returns 'true' when the column names of this list are a subset of
// another list, regardless of order
func (this *ColumnList) IsSubsetOf(other *ColumnList) bool {
for _, column := range this.Names {
if _, exists := other.Ordinals[column]; !exists {
return false
}
}
return true
}
func (this *ColumnList) Len() int {
return len(this.Names)
} }
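A small usage sketch of the reworked ColumnList, consistent with the new types_test.go at the end of this diff:

columns := NewColumnList([]string{"id", "name", "position"})
subset := NewColumnList([]string{"name", "id"})
// columns.Len() == 3, columns.Ordinals["position"] == 2
// subset.IsSubsetOf(columns) == true, columns.IsSubsetOf(subset) == false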
// UniqueKey is the combination of a key's name and columns // UniqueKey is the combination of a key's name and columns
@ -40,6 +81,10 @@ func (this *UniqueKey) IsPrimary() bool {
return this.Name == "PRIMARY" return this.Name == "PRIMARY"
} }
func (this *UniqueKey) Len() int {
return this.Columns.Len()
}
func (this *UniqueKey) String() string { func (this *UniqueKey) String() string {
return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable) return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable)
} }

29
go/sql/types_test.go Normal file
View File

@ -0,0 +1,29 @@
/*
Copyright 2016 GitHub Inc.
See https://github.com/github/gh-osc/blob/master/LICENSE
*/
package sql
import (
"testing"
"github.com/outbrain/golib/log"
test "github.com/outbrain/golib/tests"
"reflect"
)
func init() {
log.SetLevel(log.ERROR)
}
func TestParseColumnList(t *testing.T) {
names := "id,category,max_len"
columnList := ParseColumnList(names)
test.S(t).ExpectEquals(columnList.Len(), 3)
test.S(t).ExpectTrue(reflect.DeepEqual(columnList.Names, []string{"id", "category", "max_len"}))
test.S(t).ExpectEquals(columnList.Ordinals["id"], 0)
test.S(t).ExpectEquals(columnList.Ordinals["category"], 1)
test.S(t).ExpectEquals(columnList.Ordinals["max_len"], 2)
}