diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 4ad0131..7b8002d 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -592,6 +592,7 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL } if strings.HasPrefix(columnType, "enum") { column.Type = sql.EnumColumnType + column.EnumValues = sql.ParseEnumValues(m.GetString("COLUMN_TYPE")) } if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { column.Charset = charset diff --git a/go/sql/builder.go b/go/sql/builder.go index 4618ef1..878850c 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -39,7 +39,7 @@ func buildColumnsPreparedValues(columns *ColumnList) []string { if column.timezoneConversion != nil { token = fmt.Sprintf("convert_tz(?, '%s', '%s')", column.timezoneConversion.ToTimezone, "+00:00") } else if column.enumToTextConversion { - token = "concat('', ?)" + token = fmt.Sprintf("ELT(?, %s)", column.EnumValues) } else if column.Type == JSONColumnType { token = "convert(? using utf8mb4)" } else { @@ -111,7 +111,7 @@ func BuildSetPreparedClause(columns *ColumnList) (result string, err error) { if column.timezoneConversion != nil { setToken = fmt.Sprintf("%s=convert_tz(?, '%s', '%s')", EscapeName(column.Name), column.timezoneConversion.ToTimezone, "+00:00") } else if column.enumToTextConversion { - setToken = fmt.Sprintf("%s=concat('', ?)", EscapeName(column.Name)) + setToken = fmt.Sprintf("%s=ELT(?, %s)", EscapeName(column.Name), column.EnumValues) } else if column.Type == JSONColumnType { setToken = fmt.Sprintf("%s=convert(? using utf8mb4)", EscapeName(column.Name)) } else { diff --git a/go/sql/parser.go b/go/sql/parser.go index d9c0c3f..eac0bdc 100644 --- a/go/sql/parser.go +++ b/go/sql/parser.go @@ -33,6 +33,7 @@ var ( // ALTER TABLE tbl something regexp.MustCompile(`(?i)\balter\s+table\s+([\S]+)\s+(.*$)`), } + enumValuesRegexp = regexp.MustCompile("^enum[(](.*)[)]$") ) type AlterTableParser struct { @@ -205,3 +206,10 @@ func (this *AlterTableParser) HasExplicitTable() bool { func (this *AlterTableParser) GetAlterStatementOptions() string { return this.alterStatementOptions } + +func ParseEnumValues(enumColumnType string) string { + if submatch := enumValuesRegexp.FindStringSubmatch(enumColumnType); len(submatch) > 0 { + return submatch[1] + } + return enumColumnType +} diff --git a/go/sql/parser_test.go b/go/sql/parser_test.go index 6cdbb39..3157d09 100644 --- a/go/sql/parser_test.go +++ b/go/sql/parser_test.go @@ -322,3 +322,21 @@ func TestParseAlterStatementExplicitTable(t *testing.T) { test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b", "add index idx(i)"})) } } + +func TestParseEnumValues(t *testing.T) { + { + s := "enum('red','green','blue','orange')" + values := ParseEnumValues(s) + test.S(t).ExpectEquals(values, "'red','green','blue','orange'") + } + { + s := "('red','green','blue','orange')" + values := ParseEnumValues(s) + test.S(t).ExpectEquals(values, "('red','green','blue','orange')") + } + { + s := "zzz" + values := ParseEnumValues(s) + test.S(t).ExpectEquals(values, "zzz") + } +} diff --git a/go/sql/types.go b/go/sql/types.go index 30db646..44e9725 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -35,6 +35,7 @@ type Column struct { IsUnsigned bool Charset string Type ColumnType + EnumValues string timezoneConversion *TimezoneConversion enumToTextConversion bool }