- Throttling-check is now an async routine running once per second

- Throttling variables protected by mutex
- Added `--throttle-additional-flag-file`: `operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations`
- ColumnList is not a `struct` which contains ordinal mapping
- More implicit write changelog + audit changelog
- builder now builds `DELETE` and `INSERT` queries from data it will eventually get from DML event
- Sanity check for binlog_row_image
- Restarting replication to be sure binlog settings apply
- Prepare for accepting `SIGHUP` (reloading configuration)
This commit is contained in:
Shlomi Noach 2016-04-11 17:27:16 +02:00
parent 80163b35b6
commit 04525887f3
8 changed files with 434 additions and 130 deletions

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -50,18 +51,19 @@ type MigrationContext struct {
CurrentLag int64 CurrentLag int64
MaxLagMillisecondsThrottleThreshold int64 MaxLagMillisecondsThrottleThreshold int64
ThrottleFlagFile string ThrottleFlagFile string
ThrottleAdditionalFlagFile string
TotalRowsCopied int64 TotalRowsCopied int64
isThrottled int64 isThrottled bool
ThrottleReason string throttleReason string
throttleMutex *sync.Mutex
MaxLoad map[string]int64 MaxLoad map[string]int64
OriginalTableColumns sql.ColumnList OriginalTableColumns *sql.ColumnList
OriginalTableColumnsMap sql.ColumnsMap
OriginalTableUniqueKeys [](*sql.UniqueKey) OriginalTableUniqueKeys [](*sql.UniqueKey)
GhostTableColumns sql.ColumnList GhostTableColumns *sql.ColumnList
GhostTableUniqueKeys [](*sql.UniqueKey) GhostTableUniqueKeys [](*sql.UniqueKey)
UniqueKey *sql.UniqueKey UniqueKey *sql.UniqueKey
SharedColumns sql.ColumnList SharedColumns *sql.ColumnList
MigrationRangeMinValues *sql.ColumnValues MigrationRangeMinValues *sql.ColumnValues
MigrationRangeMaxValues *sql.ColumnValues MigrationRangeMaxValues *sql.ColumnValues
Iteration int64 Iteration int64
@ -83,7 +85,8 @@ func newMigrationContext() *MigrationContext {
InspectorConnectionConfig: mysql.NewConnectionConfig(), InspectorConnectionConfig: mysql.NewConnectionConfig(),
MasterConnectionConfig: mysql.NewConnectionConfig(), MasterConnectionConfig: mysql.NewConnectionConfig(),
MaxLagMillisecondsThrottleThreshold: 1000, MaxLagMillisecondsThrottleThreshold: 1000,
MaxLoad: make(map[string]int64), MaxLoad: make(map[string]int64),
throttleMutex: &sync.Mutex{},
} }
} }
@ -97,6 +100,11 @@ func (this *MigrationContext) GetGhostTableName() string {
return fmt.Sprintf("_%s_New", this.OriginalTableName) return fmt.Sprintf("_%s_New", this.OriginalTableName)
} }
// GetOldTableName generates the name of the "old" table, into which the original table is renamed.
func (this *MigrationContext) GetOldTableName() string {
return fmt.Sprintf("_%s_Old", this.OriginalTableName)
}
// GetChangelogTableName generates the name of changelog table, based on original table name // GetChangelogTableName generates the name of changelog table, based on original table name
func (this *MigrationContext) GetChangelogTableName() string { func (this *MigrationContext) GetChangelogTableName() string {
return fmt.Sprintf("_%s_OSC", this.OriginalTableName) return fmt.Sprintf("_%s_OSC", this.OriginalTableName)
@ -157,16 +165,17 @@ func (this *MigrationContext) GetIteration() int64 {
return atomic.LoadInt64(&this.Iteration) return atomic.LoadInt64(&this.Iteration)
} }
func (this *MigrationContext) SetThrottled(throttle bool) { func (this *MigrationContext) SetThrottled(throttle bool, reason string) {
if throttle { this.throttleMutex.Lock()
atomic.StoreInt64(&this.isThrottled, 1) defer func() { this.throttleMutex.Unlock() }()
} else { this.isThrottled = throttle
atomic.StoreInt64(&this.isThrottled, 0) this.throttleReason = reason
}
} }
func (this *MigrationContext) IsThrottled() bool { func (this *MigrationContext) IsThrottled() (bool, string) {
return atomic.LoadInt64(&this.isThrottled) != 0 this.throttleMutex.Lock()
defer func() { this.throttleMutex.Unlock() }()
return this.isThrottled, this.throttleReason
} }
func (this *MigrationContext) ReadMaxLoad(maxLoadList string) error { func (this *MigrationContext) ReadMaxLoad(maxLoadList string) error {

View File

@ -38,7 +38,8 @@ func main() {
migrationContext.ChunkSize = 100000 migrationContext.ChunkSize = 100000
} }
flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1000, "replication lag at which to throttle operation") flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1000, "replication lag at which to throttle operation")
flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists") flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered")
flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations")
maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'") maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'")
quiet := flag.Bool("quiet", false, "quiet") quiet := flag.Bool("quiet", false, "quiet")
verbose := flag.Bool("verbose", false, "verbose") verbose := flag.Bool("verbose", false, "verbose")

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/github/gh-osc/go/base" "github.com/github/gh-osc/go/base"
"github.com/github/gh-osc/go/binlog"
"github.com/github/gh-osc/go/mysql" "github.com/github/gh-osc/go/mysql"
"github.com/github/gh-osc/go/sql" "github.com/github/gh-osc/go/sql"
@ -63,7 +64,7 @@ func (this *Applier) validateConnection() error {
return nil return nil
} }
// CreateGhostTable creates the ghost table on the master // CreateGhostTable creates the ghost table on the applier host
func (this *Applier) CreateGhostTable() error { func (this *Applier) CreateGhostTable() error {
query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`, query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
@ -82,7 +83,7 @@ func (this *Applier) CreateGhostTable() error {
return nil return nil
} }
// CreateGhostTable creates the ghost table on the master // AlterGhost applies `alter` statement on ghost table
func (this *Applier) AlterGhost() error { func (this *Applier) AlterGhost() error {
query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`, query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
@ -101,7 +102,7 @@ func (this *Applier) AlterGhost() error {
return nil return nil
} }
// CreateChangelogTable creates the changelog table on the master // CreateChangelogTable creates the changelog table on the applier host
func (this *Applier) CreateChangelogTable() error { func (this *Applier) CreateChangelogTable() error {
query := fmt.Sprintf(`create /* gh-osc */ table %s.%s ( query := fmt.Sprintf(`create /* gh-osc */ table %s.%s (
id bigint auto_increment, id bigint auto_increment,
@ -110,7 +111,7 @@ func (this *Applier) CreateChangelogTable() error {
value varchar(255) charset ascii not null, value varchar(255) charset ascii not null,
primary key(id), primary key(id),
unique key hint_uidx(hint) unique key hint_uidx(hint)
) auto_increment=2 ) auto_increment=256
`, `,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()), sql.EscapeName(this.migrationContext.GetChangelogTableName()),
@ -126,23 +127,33 @@ func (this *Applier) CreateChangelogTable() error {
return nil return nil
} }
// DropChangelogTable drops the changelog table on the master // dropTable drops a given table on the applied host
func (this *Applier) DropChangelogTable() error { func (this *Applier) dropTable(tableName string) error {
query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`, query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`,
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()), sql.EscapeName(tableName),
) )
log.Infof("Droppping changelog table %s.%s", log.Infof("Droppping table %s.%s",
sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.DatabaseName),
sql.EscapeName(this.migrationContext.GetChangelogTableName()), sql.EscapeName(tableName),
) )
if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil {
return err return err
} }
log.Infof("Changelog table dropped") log.Infof("Table dropped")
return nil return nil
} }
// DropChangelogTable drops the changelog table on the applier host
func (this *Applier) DropChangelogTable() error {
return this.dropTable(this.migrationContext.GetChangelogTableName())
}
// DropGhostTable drops the ghost table on the applier host
func (this *Applier) DropGhostTable() error {
return this.dropTable(this.migrationContext.GetGhostTableName())
}
// WriteChangelog writes a value to the changelog table. // WriteChangelog writes a value to the changelog table.
// It returns the hint as given, for convenience // It returns the hint as given, for convenience
func (this *Applier) WriteChangelog(hint, value string) (string, error) { func (this *Applier) WriteChangelog(hint, value string) (string, error) {
@ -162,12 +173,15 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) {
return hint, err return hint, err
} }
func (this *Applier) WriteChangelogState(value string) (string, error) { func (this *Applier) WriteAndLogChangelog(hint, value string) (string, error) {
hint := "state"
this.WriteChangelog(hint, value) this.WriteChangelog(hint, value)
return this.WriteChangelog(fmt.Sprintf("%s at %d", hint, time.Now().UnixNano()), value) return this.WriteChangelog(fmt.Sprintf("%s at %d", hint, time.Now().UnixNano()), value)
} }
func (this *Applier) WriteChangelogState(value string) (string, error) {
return this.WriteAndLogChangelog("state", value)
}
// InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table.
// This is done asynchronously // This is done asynchronously
func (this *Applier) InitiateHeartbeat() { func (this *Applier) InitiateHeartbeat() {
@ -213,7 +227,7 @@ func (this *Applier) InitiateHeartbeat() {
// ReadMigrationMinValues // ReadMigrationMinValues
func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error {
log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) log.Debugf("Reading migration range according to key: %s", uniqueKey.Name)
query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names)
if err != nil { if err != nil {
return err return err
} }
@ -222,7 +236,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error {
return err return err
} }
for rows.Next() { for rows.Next() {
this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(len(uniqueKey.Columns)) this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(uniqueKey.Len())
if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil { if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil {
return err return err
} }
@ -234,7 +248,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error {
// ReadMigrationMinValues // ReadMigrationMinValues
func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error {
log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) log.Debugf("Reading migration range according to key: %s", uniqueKey.Name)
query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names)
if err != nil { if err != nil {
return err return err
} }
@ -243,7 +257,7 @@ func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error {
return err return err
} }
for rows.Next() { for rows.Next() {
this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(len(uniqueKey.Columns)) this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(uniqueKey.Len())
if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil { if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil {
return err return err
} }
@ -272,12 +286,12 @@ func (this *Applier) __unused_IterationIsComplete() (bool, error) {
return false, nil return false, nil
} }
args := sqlutils.Args() args := sqlutils.Args()
compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign)
if err != nil { if err != nil {
return false, err return false, err
} }
args = append(args, explodedArgs...) args = append(args, explodedArgs...)
compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -317,7 +331,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo
query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery(
this.migrationContext.DatabaseName, this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableName,
this.migrationContext.UniqueKey.Columns, this.migrationContext.UniqueKey.Columns.Names,
this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(),
this.migrationContext.MigrationRangeMaxValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(),
this.migrationContext.ChunkSize, this.migrationContext.ChunkSize,
@ -330,7 +344,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo
if err != nil { if err != nil {
return hasFurtherRange, err return hasFurtherRange, err
} }
iterationRangeMaxValues := sql.NewColumnValues(len(this.migrationContext.UniqueKey.Columns)) iterationRangeMaxValues := sql.NewColumnValues(this.migrationContext.UniqueKey.Len())
for rows.Next() { for rows.Next() {
if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil { if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil {
return hasFurtherRange, err return hasFurtherRange, err
@ -360,9 +374,9 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
this.migrationContext.DatabaseName, this.migrationContext.DatabaseName,
this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableName,
this.migrationContext.GetGhostTableName(), this.migrationContext.GetGhostTableName(),
this.migrationContext.SharedColumns, this.migrationContext.SharedColumns.Names,
this.migrationContext.UniqueKey.Name, this.migrationContext.UniqueKey.Name,
this.migrationContext.UniqueKey.Columns, this.migrationContext.UniqueKey.Columns.Names,
this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(),
this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(),
this.migrationContext.GetIteration() == 0, this.migrationContext.GetIteration() == 0,
@ -422,3 +436,12 @@ func (this *Applier) ShowStatusVariable(variableName string) (result int64, err
} }
return result, nil return result, nil
} }
func (this *Applier) BuildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (result string, err error) {
switch dmlEvent.DML {
case binlog.DeleteDML:
{
}
}
return result, err
}

View File

@ -44,6 +44,9 @@ func (this *Inspector) InitDBConnections() (err error) {
if err := this.validateGrants(); err != nil { if err := this.validateGrants(); err != nil {
return err return err
} }
if err := this.restartReplication(); err != nil {
return err
}
if err := this.validateBinlogs(); err != nil { if err := this.validateBinlogs(); err != nil {
return err return err
} }
@ -69,7 +72,7 @@ func (this *Inspector) ValidateOriginalTable() (err error) {
return nil return nil
} }
func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) {
uniqueKeys, err = this.getCandidateUniqueKeys(tableName) uniqueKeys, err = this.getCandidateUniqueKeys(tableName)
if err != nil { if err != nil {
return columns, uniqueKeys, err return columns, uniqueKeys, err
@ -90,7 +93,6 @@ func (this *Inspector) InspectOriginalTable() (err error) {
if err == nil { if err == nil {
return err return err
} }
this.migrationContext.OriginalTableColumnsMap = sql.NewColumnsMap(this.migrationContext.OriginalTableColumns)
return nil return nil
} }
@ -108,6 +110,11 @@ func (this *Inspector) InspectOriginalAndGhostTables() (err error) {
} }
this.migrationContext.UniqueKey = sharedUniqueKeys[0] this.migrationContext.UniqueKey = sharedUniqueKeys[0]
log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name) log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name)
if !this.migrationContext.UniqueKey.IsPrimary() {
if this.migrationContext.OriginalBinlogRowImage != "full" {
return fmt.Errorf("binlog_row_image is '%s' and chosen key is %s, which is not the primary key. This operation cannot proceed. You may `set global binlog_row_image='full'` and try again")
}
}
this.migrationContext.SharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns) this.migrationContext.SharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns)
log.Infof("Shared columns are %s", this.migrationContext.SharedColumns) log.Infof("Shared columns are %s", this.migrationContext.SharedColumns)
@ -171,6 +178,26 @@ func (this *Inspector) validateGrants() error {
return log.Errorf("User has insufficient privileges for migration.") return log.Errorf("User has insufficient privileges for migration.")
} }
// restartReplication is required so that we are _certain_ the binlog format and
// row image settings have actually been applied to the replication thread.
// It is entriely possible, for example, that the replication is using 'STATEMENT'
// binlog format even as the variable says 'ROW'
func (this *Inspector) restartReplication() error {
log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port)
var stopError, startError error
_, stopError = sqlutils.ExecNoPrepare(this.db, `stop slave`)
_, startError = sqlutils.ExecNoPrepare(this.db, `start slave`)
if stopError != nil {
return stopError
}
if startError != nil {
return startError
}
log.Debugf("Replication restarted")
return nil
}
// validateBinlogs checks that binary log configuration is good to go // validateBinlogs checks that binary log configuration is good to go
func (this *Inspector) validateBinlogs() error { func (this *Inspector) validateBinlogs() error {
query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format` query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format`
@ -299,27 +326,28 @@ func (this *Inspector) countTableRows() error {
return nil return nil
} }
func (this *Inspector) getTableColumns(databaseName, tableName string) (columns sql.ColumnList, err error) { func (this *Inspector) getTableColumns(databaseName, tableName string) (*sql.ColumnList, error) {
query := fmt.Sprintf(` query := fmt.Sprintf(`
show columns from %s.%s show columns from %s.%s
`, `,
sql.EscapeName(databaseName), sql.EscapeName(databaseName),
sql.EscapeName(tableName), sql.EscapeName(tableName),
) )
err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { columnNames := []string{}
columns = append(columns, rowMap.GetString("Field")) err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error {
columnNames = append(columnNames, rowMap.GetString("Field"))
return nil return nil
}) })
if err != nil { if err != nil {
return columns, err return nil, err
} }
if len(columns) == 0 { if len(columnNames) == 0 {
return columns, log.Errorf("Found 0 columns on %s.%s. Bailing out", return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out",
sql.EscapeName(databaseName), sql.EscapeName(databaseName),
sql.EscapeName(tableName), sql.EscapeName(tableName),
) )
} }
return columns, nil return sql.NewColumnList(columnNames), nil
} }
// getCandidateUniqueKeys investigates a table and returns the list of unique keys // getCandidateUniqueKeys investigates a table and returns the list of unique keys
@ -412,17 +440,18 @@ func (this *Inspector) getSharedUniqueKeys(originalUniqueKeys, ghostUniqueKeys [
} }
// getSharedColumns returns the intersection of two lists of columns in same order as the first list // getSharedColumns returns the intersection of two lists of columns in same order as the first list
func (this *Inspector) getSharedColumns(originalColumns, ghostColumns sql.ColumnList) (sharedColumns sql.ColumnList) { func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.ColumnList) *sql.ColumnList {
columnsInGhost := make(map[string]bool) columnsInGhost := make(map[string]bool)
for _, ghostColumn := range ghostColumns { for _, ghostColumn := range ghostColumns.Names {
columnsInGhost[ghostColumn] = true columnsInGhost[ghostColumn] = true
} }
for _, originalColumn := range originalColumns { sharedColumnNames := []string{}
for _, originalColumn := range originalColumns.Names {
if columnsInGhost[originalColumn] { if columnsInGhost[originalColumn] {
sharedColumns = append(sharedColumns, originalColumn) sharedColumnNames = append(sharedColumnNames, originalColumn)
} }
} }
return sharedColumns return sql.NewColumnList(sharedColumnNames)
} }
func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) {

View File

@ -8,8 +8,10 @@ package logic
import ( import (
"fmt" "fmt"
"os" "os"
"os/signal"
"regexp" "regexp"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"github.com/github/gh-osc/go/base" "github.com/github/gh-osc/go/base"
@ -74,6 +76,21 @@ func prettifyDurationOutput(d time.Duration) string {
return result return result
} }
// acceptSignals registers for OS signals
func (this *Migrator) acceptSignals() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGHUP)
go func() {
for sig := range c {
switch sig {
case syscall.SIGHUP:
log.Debugf("Received SIGHUP. Reloading configuration")
}
}
}()
}
func (this *Migrator) shouldThrottle() (result bool, reason string) { func (this *Migrator) shouldThrottle() (result bool, reason string) {
lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) lag := atomic.LoadInt64(&this.migrationContext.CurrentLag)
@ -82,7 +99,13 @@ func (this *Migrator) shouldThrottle() (result bool, reason string) {
} }
if this.migrationContext.ThrottleFlagFile != "" { if this.migrationContext.ThrottleFlagFile != "" {
if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil { if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil {
//Throttle file defined and exists! // Throttle file defined and exists!
return true, "flag-file"
}
}
if this.migrationContext.ThrottleAdditionalFlagFile != "" {
if _, err := os.Stat(this.migrationContext.ThrottleAdditionalFlagFile); err == nil {
// 2nd Throttle file defined and exists!
return true, "flag-file" return true, "flag-file"
} }
} }
@ -100,37 +123,43 @@ func (this *Migrator) shouldThrottle() (result bool, reason string) {
return false, "" return false, ""
} }
func (this *Migrator) initiateThrottler() error {
throttlerTick := time.Tick(1 * time.Second)
throttlerFunction := func() {
alreadyThrottling, currentReason := this.migrationContext.IsThrottled()
shouldThrottle, throttleReason := this.shouldThrottle()
if shouldThrottle && !alreadyThrottling {
// New throttling
this.applier.WriteAndLogChangelog("throttle", throttleReason)
} else if shouldThrottle && alreadyThrottling && (currentReason != throttleReason) {
// Change of reason
this.applier.WriteAndLogChangelog("throttle", throttleReason)
} else if alreadyThrottling && !shouldThrottle {
// End of throttling
this.applier.WriteAndLogChangelog("throttle", "done throttling")
}
this.migrationContext.SetThrottled(shouldThrottle, throttleReason)
}
throttlerFunction()
for range throttlerTick {
throttlerFunction()
}
return nil
}
// throttle initiates a throttling event, if need be, updates the Context and // throttle initiates a throttling event, if need be, updates the Context and
// calls callback functions, if any // calls callback functions, if any
func (this *Migrator) throttle( func (this *Migrator) throttle(onThrottled func()) {
onStartThrottling func(),
onContinuousThrottling func(),
onEndThrottling func(),
) {
hasThrottledYet := false
for { for {
shouldThrottle, reason := this.shouldThrottle() if shouldThrottle, _ := this.migrationContext.IsThrottled(); !shouldThrottle {
if !shouldThrottle { return
break
} }
this.migrationContext.ThrottleReason = reason if onThrottled != nil {
if !hasThrottledYet { onThrottled()
hasThrottledYet = true
if onStartThrottling != nil {
onStartThrottling()
}
this.migrationContext.SetThrottled(true)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
if onContinuousThrottling != nil {
onContinuousThrottling()
}
}
if hasThrottledYet {
if onEndThrottling != nil {
onEndThrottling()
}
this.migrationContext.SetThrottled(false)
} }
} }
@ -239,6 +268,7 @@ func (this *Migrator) Migrate() (err error) {
if err := this.applier.ReadMigrationRangeValues(); err != nil { if err := this.applier.ReadMigrationRangeValues(); err != nil {
return err return err
} }
go this.initiateThrottler()
go this.executeWriteFuncs() go this.executeWriteFuncs()
go this.iterateChunks() go this.iterateChunks()
this.migrationContext.RowCopyStartTime = time.Now() this.migrationContext.RowCopyStartTime = time.Now()
@ -249,15 +279,9 @@ func (this *Migrator) Migrate() (err error) {
log.Debugf("Row copy complete") log.Debugf("Row copy complete")
this.printStatus() this.printStatus()
this.throttle( this.throttle(func() {
func() { log.Debugf("throttling on LOCK TABLES")
log.Debugf("throttling before LOCK TABLES") })
},
nil,
func() {
log.Debugf("done throttling")
},
)
// TODO retries!! // TODO retries!!
this.applier.LockTables() this.applier.LockTables()
this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed)) this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed))
@ -304,8 +328,8 @@ func (this *Migrator) printStatus() {
} }
eta := "N/A" eta := "N/A"
if this.migrationContext.IsThrottled() { if isThrottled, throttleReason := this.migrationContext.IsThrottled(); isThrottled {
eta = fmt.Sprintf("throttled, %s", this.migrationContext.ThrottleReason) eta = fmt.Sprintf("throttled, %s", throttleReason)
} }
status := fmt.Sprintf("Copy: %d/%d %.1f%%; Backlog: %d/%d; Elapsed: %+v(copy), %+v(total); ETA: %s", status := fmt.Sprintf("Copy: %d/%d %.1f%%; Backlog: %d/%d; Elapsed: %+v(copy), %+v(total); ETA: %s",
totalRowsCopied, rowsEstimate, progressPct, totalRowsCopied, rowsEstimate, progressPct,
@ -399,14 +423,8 @@ func (this *Migrator) iterateChunks() error {
} }
func (this *Migrator) executeWriteFuncs() error { func (this *Migrator) executeWriteFuncs() error {
onStartThrottling := func() {
log.Debugf("throttling writes")
}
onEndThrottling := func() {
log.Debugf("done throttling writes")
}
for { for {
this.throttle(onStartThrottling, nil, onEndThrottling) this.throttle(nil)
// We give higher priority to event processing, then secondary priority to // We give higher priority to event processing, then secondary priority to
// rowcopy // rowcopy
select { select {

View File

@ -32,6 +32,20 @@ func EscapeName(name string) string {
return fmt.Sprintf("`%s`", name) return fmt.Sprintf("`%s`", name)
} }
func buildPreparedValues(length int) []string {
values := make([]string, length, length)
for i := 0; i < length; i++ {
values[i] = "?"
}
return values
}
func duplicateNames(names []string) []string {
duplicate := make([]string, len(names), len(names))
copy(duplicate, names)
return duplicate
}
func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) { func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) {
if column == "" { if column == "" {
return "", fmt.Errorf("Empty column in GetValueComparison") return "", fmt.Errorf("Empty column in GetValueComparison")
@ -64,6 +78,11 @@ func BuildEqualsComparison(columns []string, values []string) (result string, er
return result, nil return result, nil
} }
func BuildEqualsPreparedComparison(columns []string) (result string, err error) {
values := buildPreparedValues(len(columns))
return BuildEqualsComparison(columns, values)
}
func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
if len(columns) == 0 { if len(columns) == 0 {
return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison")
@ -121,10 +140,7 @@ func BuildRangeComparison(columns []string, values []string, args []interface{},
} }
func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
values := make([]string, len(columns), len(columns)) values := buildPreparedValues(len(columns))
for i := range columns {
values[i] = "?"
}
return BuildRangeComparison(columns, values, args, comparisonSign) return BuildRangeComparison(columns, values, args, comparisonSign)
} }
@ -135,6 +151,7 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin
databaseName = EscapeName(databaseName) databaseName = EscapeName(databaseName)
originalTableName = EscapeName(originalTableName) originalTableName = EscapeName(originalTableName)
ghostTableName = EscapeName(ghostTableName) ghostTableName = EscapeName(ghostTableName)
sharedColumns = duplicateNames(sharedColumns)
for i := range sharedColumns { for i := range sharedColumns {
sharedColumns[i] = EscapeName(sharedColumns[i]) sharedColumns[i] = EscapeName(sharedColumns[i])
} }
@ -171,12 +188,8 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin
} }
func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) {
rangeStartValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) rangeStartValues := buildPreparedValues(len(uniqueKeyColumns))
rangeEndValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) rangeEndValues := buildPreparedValues(len(uniqueKeyColumns))
for i := range uniqueKeyColumns {
rangeStartValues[i] = "?"
rangeEndValues[i] = "?"
}
return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable)
} }
@ -198,6 +211,7 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK
} }
explodedArgs = append(explodedArgs, rangeExplodedArgs...) explodedArgs = append(explodedArgs, rangeExplodedArgs...)
uniqueKeyColumns = duplicateNames(uniqueKeyColumns)
uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
for i := range uniqueKeyColumns { for i := range uniqueKeyColumns {
@ -244,6 +258,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni
databaseName = EscapeName(databaseName) databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName) tableName = EscapeName(tableName)
uniqueKeyColumns = duplicateNames(uniqueKeyColumns)
uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns))
for i := range uniqueKeyColumns { for i := range uniqueKeyColumns {
uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i])
@ -262,3 +277,65 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni
) )
return query, nil return query, nil
} }
func BuildDMLDeleteQuery(databaseName, tableName string, originalTableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) {
if len(args) != originalTableColumns.Len() {
return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery")
}
for _, column := range uniqueKeyColumns.Names {
tableOrdinal := originalTableColumns.Ordinals[column]
uniqueKeyArgs = append(uniqueKeyArgs, args[tableOrdinal])
}
databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName)
equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names)
result = fmt.Sprintf(`
delete /* gh-osc %s.%s */
from
%s.%s
where
%s
`, databaseName, tableName,
databaseName, tableName,
equalsComparison,
)
return result, uniqueKeyArgs, err
}
func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) {
if len(args) != originalTableColumns.Len() {
return result, args, fmt.Errorf("args count differs from table column count in BuildDMLInsertQuery")
}
if !sharedColumns.IsSubsetOf(originalTableColumns) {
return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery")
}
if sharedColumns.Len() == 0 {
return result, args, fmt.Errorf("No shared columns found in BuildDMLInsertQuery")
}
databaseName = EscapeName(databaseName)
tableName = EscapeName(tableName)
for _, column := range sharedColumns.Names {
tableOrdinal := originalTableColumns.Ordinals[column]
sharedArgs = append(sharedArgs, args[tableOrdinal])
}
sharedColumnNames := duplicateNames(sharedColumns.Names)
for i := range sharedColumnNames {
sharedColumnNames[i] = EscapeName(sharedColumnNames[i])
}
preparedValues := buildPreparedValues(sharedColumns.Len())
result = fmt.Sprintf(`
replace /* gh-osc %s.%s */ into
%s.%s
(%s)
values
(%s)
`, databaseName, tableName,
databaseName, tableName,
strings.Join(sharedColumnNames, ", "),
strings.Join(preparedValues, ", "),
)
return result, sharedArgs, err
}

View File

@ -68,6 +68,15 @@ func TestBuildEqualsComparison(t *testing.T) {
} }
} }
func TestBuildEqualsPreparedComparison(t *testing.T) {
{
columns := []string{"c1", "c2"}
comparison, err := BuildEqualsPreparedComparison(columns)
test.S(t).ExpectNil(err)
test.S(t).ExpectEquals(comparison, "((`c1` = ?) and (`c2` = ?))")
}
}
func TestBuildRangeComparison(t *testing.T) { func TestBuildRangeComparison(t *testing.T) {
{ {
columns := []string{"c1"} columns := []string{"c1"}
@ -143,7 +152,7 @@ func TestBuildRangeInsertQuery(t *testing.T) {
rangeStartArgs := []interface{}{3} rangeStartArgs := []interface{}{3}
rangeEndArgs := []interface{}{103} rangeEndArgs := []interface{}{103}
query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false)
test.S(t).ExpectNil(err) test.S(t).ExpectNil(err)
expected := ` expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
@ -162,7 +171,7 @@ func TestBuildRangeInsertQuery(t *testing.T) {
rangeStartArgs := []interface{}{3, 17} rangeStartArgs := []interface{}{3, 17}
rangeEndArgs := []interface{}{103, 117} rangeEndArgs := []interface{}{103, 117}
query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false)
test.S(t).ExpectNil(err) test.S(t).ExpectNil(err)
expected := ` expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
@ -186,13 +195,13 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) {
rangeStartArgs := []interface{}{3, 17} rangeStartArgs := []interface{}{3, 17}
rangeEndArgs := []interface{}{103, 117} rangeEndArgs := []interface{}{103, 117}
query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true) query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true, true)
test.S(t).ExpectNil(err) test.S(t).ExpectNil(err)
expected := ` expected := `
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
(select id, name, position from mydb.tbl force index (name_position_uidx) (select id, name, position from mydb.tbl force index (name_position_uidx)
where (((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?))) and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) and (position = ?)))) where (((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?))) and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) and (position = ?))))
) lock in share mode )
` `
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117})) test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117}))
@ -202,7 +211,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) {
func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) {
databaseName := "mydb" databaseName := "mydb"
originalTableName := "tbl" originalTableName := "tbl"
chunkSize := 500 var chunkSize int64 = 500
{ {
uniqueKeyColumns := []string{"name", "position"} uniqueKeyColumns := []string{"name", "position"}
rangeStartArgs := []interface{}{3, 17} rangeStartArgs := []interface{}{3, 17}
@ -262,3 +271,107 @@ func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) {
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
} }
} }
func TestBuildDMLDeleteQuery(t *testing.T) {
databaseName := "mydb"
tableName := "tbl"
tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"})
args := []interface{}{3, "testname", "first", 17, 23}
{
uniqueKeyColumns := NewColumnList([]string{"position"})
query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNil(err)
expected := `
delete /* gh-osc mydb.tbl */
from
mydb.tbl
where
((position = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17}))
}
{
uniqueKeyColumns := NewColumnList([]string{"name", "position"})
query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNil(err)
expected := `
delete /* gh-osc mydb.tbl */
from
mydb.tbl
where
((name = ?) and (position = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{"testname", 17}))
}
{
uniqueKeyColumns := NewColumnList([]string{"position", "name"})
query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNil(err)
expected := `
delete /* gh-osc mydb.tbl */
from
mydb.tbl
where
((position = ?) and (name = ?))
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"}))
}
{
uniqueKeyColumns := NewColumnList([]string{"position", "name"})
args := []interface{}{"first", 17}
_, _, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args)
test.S(t).ExpectNotNil(err)
}
}
func TestBuildDMLInsertQuery(t *testing.T) {
databaseName := "mydb"
tableName := "tbl"
tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"})
args := []interface{}{3, "testname", "first", 17, 23}
{
sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNil(err)
expected := `
replace /* gh-osc mydb.tbl */
into mydb.tbl
(id, name, position, age)
values
(?, ?, ?, ?)
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23}))
}
{
sharedColumns := NewColumnList([]string{"position", "name", "age", "id"})
query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNil(err)
expected := `
replace /* gh-osc mydb.tbl */
into mydb.tbl
(position, name, age, id)
values
(?, ?, ?, ?)
`
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{17, "testname", 23, 3}))
}
{
sharedColumns := NewColumnList([]string{"position", "name", "surprise", "id"})
_, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNotNil(err)
}
{
sharedColumns := NewColumnList([]string{})
_, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args)
test.S(t).ExpectNotNil(err)
}
}

View File

@ -11,34 +11,64 @@ import (
"strings" "strings"
) )
// ColumnList makes for a named list of columns
type ColumnList []string
// ParseColumnList parses a comma delimited list of column names
func ParseColumnList(columns string) *ColumnList {
result := ColumnList(strings.Split(columns, ","))
return &result
}
func (this *ColumnList) String() string {
return strings.Join(*this, ",")
}
func (this *ColumnList) Equals(other *ColumnList) bool {
return reflect.DeepEqual(*this, *other)
}
// ColumnsMap maps a column onto its ordinal position // ColumnsMap maps a column onto its ordinal position
type ColumnsMap map[string]int type ColumnsMap map[string]int
func NewColumnsMap(columnList ColumnList) ColumnsMap { func NewColumnsMap(orderedNames []string) ColumnsMap {
columnsMap := make(map[string]int) columnsMap := make(map[string]int)
for i, column := range columnList { for i, column := range orderedNames {
columnsMap[column] = i columnsMap[column] = i
} }
return ColumnsMap(columnsMap) return ColumnsMap(columnsMap)
} }
// ColumnList makes for a named list of columns
type ColumnList struct {
Names []string
Ordinals ColumnsMap
}
// NewColumnList creates an object given ordered list of column names
func NewColumnList(names []string) *ColumnList {
result := &ColumnList{
Names: names,
}
result.Ordinals = NewColumnsMap(result.Names)
return result
}
// ParseColumnList parses a comma delimited list of column names
func ParseColumnList(columns string) *ColumnList {
result := &ColumnList{
Names: strings.Split(columns, ","),
}
result.Ordinals = NewColumnsMap(result.Names)
return result
}
func (this *ColumnList) String() string {
return strings.Join(this.Names, ",")
}
func (this *ColumnList) Equals(other *ColumnList) bool {
return reflect.DeepEqual(this.Names, other.Names)
}
// IsSubsetOf returns 'true' when column names of this list are a subset of
// another list, in arbitrary order (order agnostic)
func (this *ColumnList) IsSubsetOf(other *ColumnList) bool {
for _, column := range this.Names {
if _, exists := other.Ordinals[column]; !exists {
return false
}
}
return true
}
func (this *ColumnList) Len() int {
return len(this.Names)
}
// UniqueKey is the combination of a key's name and columns // UniqueKey is the combination of a key's name and columns
type UniqueKey struct { type UniqueKey struct {
Name string Name string
@ -51,6 +81,10 @@ func (this *UniqueKey) IsPrimary() bool {
return this.Name == "PRIMARY" return this.Name == "PRIMARY"
} }
func (this *UniqueKey) Len() int {
return this.Columns.Len()
}
func (this *UniqueKey) String() string { func (this *UniqueKey) String() string {
return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable) return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable)
} }