diff --git a/go/logic/inspect.go b/go/logic/inspect.go index ce1fcfb..c9744a8 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -173,8 +173,7 @@ func (this *Inspector) inspectOriginalAndGhostTables() (err error) { // This additional step looks at which columns are unsigned. We could have merged this within // the `getTableColumns()` function, but it's a later patch and introduces some complexity; I feel // comfortable in doing this as a separate step. - this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns) - this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &this.migrationContext.UniqueKey.Columns) + this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns) this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.GhostTableColumns, this.migrationContext.MappedSharedColumns) for i := range this.migrationContext.SharedColumns.Columns() { @@ -552,44 +551,35 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { columnName := m.GetString("COLUMN_NAME") columnType := m.GetString("COLUMN_TYPE") - if strings.Contains(columnType, "unsigned") { - for _, columnsList := range columnsLists { - columnsList.SetUnsigned(columnName) + for _, columnsList := range columnsLists { + column := columnsList.GetColumn(columnName) + if column == nil { + continue } - } - if strings.Contains(columnType, "mediumint") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.MediumIntColumnType + + if strings.Contains(columnType, "unsigned") { + column.IsUnsigned = true } - } - if strings.Contains(columnType, "timestamp") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.TimestampColumnType + if strings.Contains(columnType, "mediumint") { + column.Type = sql.MediumIntColumnType } - } - if strings.Contains(columnType, "datetime") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.DateTimeColumnType + if strings.Contains(columnType, "timestamp") { + column.Type = sql.TimestampColumnType } - } - if strings.Contains(columnType, "json") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.JSONColumnType + if strings.Contains(columnType, "datetime") { + column.Type = sql.DateTimeColumnType } - } - if strings.Contains(columnType, "float") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.FloatColumnType + if strings.Contains(columnType, "json") { + column.Type = sql.JSONColumnType } - } - if strings.HasPrefix(columnType, "enum") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.EnumColumnType + if strings.Contains(columnType, "float") { + column.Type = sql.FloatColumnType } - } - if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { - for _, columnsList := range columnsLists { - columnsList.SetCharset(columnName, charset) + if strings.HasPrefix(columnType, "enum") { + column.Type = sql.EnumColumnType + } + if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { + column.Charset = charset } } return nil