diff --git a/go/logic/inspect.go b/go/logic/inspect.go index af84b4c..3a9cc39 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -56,9 +56,6 @@ func (this *Inspector) InitDBConnections() (err error) { if err := this.applyBinlogFormat(); err != nil { return err } - if err := this.validateAndReadTimeZone(); err != nil { - return err - } return nil } @@ -141,6 +138,14 @@ func (this *Inspector) InspectOriginalAndGhostTables() (err error) { this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns) this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.GhostTableColumns, this.migrationContext.MappedSharedColumns) + for i := range this.migrationContext.SharedColumns.Columns() { + column := this.migrationContext.SharedColumns.Columns()[i] + mappedColumn := this.migrationContext.MappedSharedColumns.Columns()[i] + if column.Name == mappedColumn.Name && column.Type == sql.DateTimeColumnType && mappedColumn.Type == sql.TimestampColumnType { + this.migrationContext.MappedSharedColumns.SetConvertDatetimeToTimestamp(column.Name, this.migrationContext.ApplierTimeZone) + } + } + return nil } @@ -158,18 +163,6 @@ func (this *Inspector) validateConnection() error { return nil } -// validateAndReadTimeZone potentially reads server time-zone -func (this *Inspector) validateAndReadTimeZone() error { - if this.migrationContext.InspectorTimeZone == "" { - query := `select @@global.time_zone` - if err := this.db.QueryRow(query).Scan(&this.migrationContext.InspectorTimeZone); err != nil { - return err - } - } - log.Infof("will use time_zone='%s' on inspector", this.migrationContext.InspectorTimeZone) - return nil -} - // validateGrants verifies the user by which we're executing has necessary grants // to do its thang. func (this *Inspector) validateGrants() error { @@ -517,11 +510,22 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL ` err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { columnName := m.GetString("COLUMN_NAME") - if strings.Contains(m.GetString("COLUMN_TYPE"), "unsigned") { + columnType := m.GetString("COLUMN_TYPE") + if strings.Contains(columnType, "unsigned") { for _, columnsList := range columnsLists { columnsList.SetUnsigned(columnName) } } + if strings.Contains(columnType, "timestamp") { + for _, columnsList := range columnsLists { + columnsList.GetColumn(columnName).Type = sql.TimestampColumnType + } + } + if strings.Contains(columnType, "datetime") { + for _, columnsList := range columnsLists { + columnsList.GetColumn(columnName).Type = sql.DateTimeColumnType + } + } if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { for _, columnsList := range columnsLists { columnsList.SetCharset(columnName, charset)