diff --git a/build.sh b/build.sh index 947a68d..b930269 100755 --- a/build.sh +++ b/build.sh @@ -2,7 +2,7 @@ # # -RELEASE_VERSION="1.0.23" +RELEASE_VERSION="1.0.24" function build { osname=$1 diff --git a/go/logic/applier.go b/go/logic/applier.go index eb23af8..07aa436 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -315,7 +315,7 @@ func (this *Applier) ExecuteThrottleQuery() (int64, error) { // ReadMigrationMinValues returns the minimum values to be iterated on rowcopy func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names()) + query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &uniqueKey.Columns) if err != nil { return err } @@ -336,7 +336,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { // ReadMigrationMaxValues returns the maximum values to be iterated on rowcopy func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names()) + query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &uniqueKey.Columns) if err != nil { return err } @@ -377,7 +377,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, - this.migrationContext.UniqueKey.Columns.Names(), + &this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(), atomic.LoadInt64(&this.migrationContext.ChunkSize), @@ -419,7 +419,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.SharedColumns.Names(), this.migrationContext.MappedSharedColumns.Names(), this.migrationContext.UniqueKey.Name, - this.migrationContext.UniqueKey.Columns.Names(), + &this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), this.migrationContext.GetIteration() == 0, diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 1b6f7f0..441afec 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -135,7 +135,7 @@ func (this *Inspector) InspectOriginalAndGhostTables() (err error) { // This additional step looks at which columns are unsigned. We could have merged this within // the `getTableColumns()` function, but it's a later patch and introduces some complexity; I feel // comfortable in doing this as a separate step. - this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns) + this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns) this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.GhostTableColumns, this.migrationContext.MappedSharedColumns) for i := range this.migrationContext.SharedColumns.Columns() { @@ -532,6 +532,11 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL columnsList.GetColumn(columnName).Type = sql.DateTimeColumnType } } + if strings.HasPrefix(columnType, "enum") { + for _, columnsList := range columnsLists { + columnsList.GetColumn(columnName).Type = sql.EnumColumnValue + } + } if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { for _, columnsList := range columnsLists { columnsList.SetCharset(columnName, charset) diff --git a/go/sql/builder.go b/go/sql/builder.go index ba587c2..c455992 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -170,12 +170,12 @@ func BuildRangeComparison(columns []string, values []string, args []interface{}, return result, explodedArgs, nil } -func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { - values := buildPreparedValues(len(columns)) - return BuildRangeComparison(columns, values, args, comparisonSign) +func BuildRangePreparedComparison(columns *ColumnList, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { + values := buildColumnsPreparedValues(columns) + return BuildRangeComparison(columns.Names(), values, args, comparisonSign) } -func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { +func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { if len(sharedColumns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery") } @@ -200,12 +200,12 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin if includeRangeStartValues { minRangeComparisonSign = GreaterThanOrEqualsComparisonSign } - rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, rangeStartArgs, minRangeComparisonSign) + rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeStartValues, rangeStartArgs, minRangeComparisonSign) if err != nil { return "", explodedArgs, err } explodedArgs = append(explodedArgs, rangeExplodedArgs...) - rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, rangeEndArgs, LessThanOrEqualsComparisonSign) + rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeEndValues, rangeEndArgs, LessThanOrEqualsComparisonSign) if err != nil { return "", explodedArgs, err } @@ -225,14 +225,14 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin return result, explodedArgs, nil } -func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { - rangeStartValues := buildPreparedValues(len(uniqueKeyColumns)) - rangeEndValues := buildPreparedValues(len(uniqueKeyColumns)) +func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { + rangeStartValues := buildColumnsPreparedValues(uniqueKeyColumns) + rangeEndValues := buildColumnsPreparedValues(uniqueKeyColumns) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) } -func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) { - if len(uniqueKeyColumns) == 0 { +func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) { + if uniqueKeyColumns.Len() == 0 { return "", explodedArgs, fmt.Errorf("Got 0 columns in BuildUniqueKeyRangeEndPreparedQuery") } databaseName = EscapeName(databaseName) @@ -253,13 +253,18 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK } explodedArgs = append(explodedArgs, rangeExplodedArgs...) - uniqueKeyColumns = duplicateNames(uniqueKeyColumns) - uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - for i := range uniqueKeyColumns { - uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) - uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumns[i]) - uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumns[i]) + uniqueKeyColumnNames := duplicateNames(uniqueKeyColumns.Names()) + uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) + uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) + for i, column := range uniqueKeyColumns.Columns() { + uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i]) + if column.Type == EnumColumnValue { + uniqueKeyColumnAscending[i] = fmt.Sprintf("concat(%s) asc", uniqueKeyColumnNames[i]) + uniqueKeyColumnDescending[i] = fmt.Sprintf("concat(%s) desc", uniqueKeyColumnNames[i]) + } else { + uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumnNames[i]) + uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumnNames[i]) + } } result = fmt.Sprintf(` select /* gh-ost %s.%s %s */ %s @@ -276,8 +281,8 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK order by %s limit 1 - `, databaseName, tableName, hint, strings.Join(uniqueKeyColumns, ", "), - strings.Join(uniqueKeyColumns, ", "), databaseName, tableName, + `, databaseName, tableName, hint, strings.Join(uniqueKeyColumnNames, ", "), + strings.Join(uniqueKeyColumnNames, ", "), databaseName, tableName, rangeStartComparison, rangeEndComparison, strings.Join(uniqueKeyColumnAscending, ", "), chunkSize, strings.Join(uniqueKeyColumnDescending, ", "), @@ -285,26 +290,30 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK return result, explodedArgs, nil } -func BuildUniqueKeyMinValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) { +func BuildUniqueKeyMinValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList) (string, error) { return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "asc") } -func BuildUniqueKeyMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) { +func BuildUniqueKeyMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList) (string, error) { return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "desc") } -func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, order string) (string, error) { - if len(uniqueKeyColumns) == 0 { +func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList, order string) (string, error) { + if uniqueKeyColumns.Len() == 0 { return "", fmt.Errorf("Got 0 columns in BuildUniqueKeyMinMaxValuesPreparedQuery") } databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) - uniqueKeyColumns = duplicateNames(uniqueKeyColumns) - uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - for i := range uniqueKeyColumns { - uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) - uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumns[i], order) + uniqueKeyColumnNames := duplicateNames(uniqueKeyColumns.Names()) + uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) + for i, column := range uniqueKeyColumns.Columns() { + uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i]) + if column.Type == EnumColumnValue { + uniqueKeyColumnOrder[i] = fmt.Sprintf("concat(%s) %s", uniqueKeyColumnNames[i], order) + } else { + uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumnNames[i], order) + } } query := fmt.Sprintf(` select /* gh-ost %s.%s */ %s @@ -313,7 +322,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni order by %s limit 1 - `, databaseName, tableName, strings.Join(uniqueKeyColumns, ", "), + `, databaseName, tableName, strings.Join(uniqueKeyColumnNames, ", "), databaseName, tableName, strings.Join(uniqueKeyColumnOrder, ", "), ) diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index 806b77b..46c44e1 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -166,7 +166,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { sharedColumns := []string{"id", "name", "position"} { uniqueKey := "PRIMARY" - uniqueKeyColumns := []string{"id"} + uniqueKeyColumns := NewColumnList([]string{"id"}) rangeStartValues := []string{"@v1s"} rangeEndValues := []string{"@v1e"} rangeStartArgs := []interface{}{3} @@ -185,7 +185,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { } { uniqueKey := "name_position_uidx" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartValues := []string{"@v1s", "@v2s"} rangeEndValues := []string{"@v1e", "@v2e"} rangeStartArgs := []interface{}{3, 17} @@ -212,7 +212,7 @@ func TestBuildRangeInsertQueryRenameMap(t *testing.T) { mappedSharedColumns := []string{"id", "name", "location"} { uniqueKey := "PRIMARY" - uniqueKeyColumns := []string{"id"} + uniqueKeyColumns := NewColumnList([]string{"id"}) rangeStartValues := []string{"@v1s"} rangeEndValues := []string{"@v1e"} rangeStartArgs := []interface{}{3} @@ -231,7 +231,7 @@ func TestBuildRangeInsertQueryRenameMap(t *testing.T) { } { uniqueKey := "name_position_uidx" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartValues := []string{"@v1s", "@v2s"} rangeEndValues := []string{"@v1e", "@v2e"} rangeStartArgs := []interface{}{3, 17} @@ -257,7 +257,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { sharedColumns := []string{"id", "name", "position"} { uniqueKey := "name_position_uidx" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} @@ -279,7 +279,7 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { originalTableName := "tbl" var chunkSize int64 = 500 { - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} @@ -309,7 +309,7 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) { databaseName := "mydb" originalTableName := "tbl" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) { query, err := BuildUniqueKeyMinValuesPreparedQuery(databaseName, originalTableName, uniqueKeyColumns) test.S(t).ExpectNil(err) diff --git a/go/sql/types.go b/go/sql/types.go index 1c57fbb..720f92f 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -18,6 +18,7 @@ const ( UnknownColumnType ColumnType = iota TimestampColumnType = iota DateTimeColumnType = iota + EnumColumnValue = iota ) type TimezoneConvertion struct { diff --git a/localtests/enum-pk/extra_args b/localtests/enum-pk/extra_args deleted file mode 100644 index f369a56..0000000 --- a/localtests/enum-pk/extra_args +++ /dev/null @@ -1 +0,0 @@ ---alter="change e e enum('red', 'green', 'blue', 'orange', 'yellow') null default null collate 'utf8_bin'"