diff --git a/go/base/context.go b/go/base/context.go new file mode 100644 index 0000000..f238878 --- /dev/null +++ b/go/base/context.go @@ -0,0 +1,14 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package base + +import () + +type MigrationContext struct { + DatabaseName string + OriginalTableName string + GhostTableName string +} diff --git a/go/sql/builder.go b/go/sql/builder.go new file mode 100644 index 0000000..d1086bb --- /dev/null +++ b/go/sql/builder.go @@ -0,0 +1,138 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package sql + +import ( + "fmt" + "strconv" + "strings" +) + +type ValueComparisonSign string + +const ( + LessThanComparisonSign ValueComparisonSign = "<" + LessThanOrEqualsComparisonSign = "<=" + EqualsComparisonSign = "=" + GreaterThanOrEqualsComparisonSign = ">=" + GreaterThanComparisonSign = ">" + NotEqualsComparisonSign = "!=" +) + +// EscapeName will escape a db/table/column/... name by wrapping with backticks. +// It is not fool proof. I'm just trying to do the right thing here, not solving +// SQL injection issues, which should be irrelevant for this tool. +func EscapeName(name string) string { + if unquoted, err := strconv.Unquote(name); err == nil { + name = unquoted + } + return fmt.Sprintf("`%s`", name) +} + +func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) { + if column == "" { + return "", fmt.Errorf("Empty column in GetValueComparison") + } + if value == "" { + return "", fmt.Errorf("Empty value in GetValueComparison") + } + comparison := fmt.Sprintf("(%s %s %s)", EscapeName(column), string(comparisonSign), value) + return comparison, err +} + +func BuildEqualsComparison(columns []string, values []string) (result string, err error) { + if len(columns) == 0 { + return "", fmt.Errorf("Got 0 columns in GetEqualsComparison") + } + if len(columns) != len(values) { + return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values)) + } + comparisons := []string{} + for i, column := range columns { + value := values[i] + comparison, err := BuildValueComparison(column, value, EqualsComparisonSign) + if err != nil { + return "", err + } + comparisons = append(comparisons, comparison) + } + result = strings.Join(comparisons, " and ") + result = fmt.Sprintf("(%s)", result) + return result, nil +} + +func BuildRangeComparison(columns []string, values []string, comparisonSign ValueComparisonSign) (result string, err error) { + if len(columns) == 0 { + return "", fmt.Errorf("Got 0 columns in GetRangeComparison") + } + if len(columns) != len(values) { + return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values)) + } + includeEquals := false + if comparisonSign == LessThanOrEqualsComparisonSign { + comparisonSign = LessThanComparisonSign + includeEquals = true + } + if comparisonSign == GreaterThanOrEqualsComparisonSign { + comparisonSign = GreaterThanComparisonSign + includeEquals = true + } + comparisons := []string{} + + for i, column := range columns { + // + value := values[i] + rangeComparison, err := BuildValueComparison(column, value, comparisonSign) + if err != nil { + return "", err + } + if len(columns[0:i]) > 0 { + equalitiesComparison, err := BuildEqualsComparison(columns[0:i], values[0:i]) + if err != nil { + return "", err + } + comparison := fmt.Sprintf("(%s AND %s)", equalitiesComparison, rangeComparison) + comparisons = append(comparisons, comparison) + } else { + comparisons = append(comparisons, rangeComparison) + } + } + + if includeEquals { + comparison, err := BuildEqualsComparison(columns, values) + if err != nil { + return "", nil + } + comparisons = append(comparisons, comparison) + } + result = strings.Join(comparisons, " or ") + result = fmt.Sprintf("(%s)", result) + return result, nil +} + +func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string) (string, error) { + if len(sharedColumns) == 0 { + return "", fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery") + } + sharedColumnsListing := strings.Join(sharedColumns, ", ") + rangeStartComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, GreaterThanOrEqualsComparisonSign) + if err != nil { + return "", err + } + rangeEndComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, LessThanOrEqualsComparisonSign) + if err != nil { + return "", err + } + query := fmt.Sprintf(` + insert /* gh-osc %s.%s */ ignore into %s.%s (%s) + (select %s from %s.%s force index (%s) + where (%s and %s) + ) + `, databaseName, originalTableName, databaseName, ghostTableName, sharedColumnsListing, + sharedColumnsListing, databaseName, originalTableName, uniqueKey, + rangeStartComparison, rangeEndComparison) + return query, nil +} diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go new file mode 100644 index 0000000..9994842 --- /dev/null +++ b/go/sql/builder_test.go @@ -0,0 +1,157 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package sql + +import ( + "testing" + + "regexp" + "strings" + + "github.com/outbrain/golib/log" + test "github.com/outbrain/golib/tests" +) + +var ( + spacesRegexp = regexp.MustCompile(`[ \t\n\r]+`) +) + +func init() { + log.SetLevel(log.ERROR) +} + +func normalizeQuery(name string) string { + name = strings.Replace(name, "`", "", -1) + name = spacesRegexp.ReplaceAllString(name, " ") + name = strings.TrimSpace(name) + return name +} + +func TestEscapeName(t *testing.T) { + names := []string{"my_table", `"my_table"`, "`my_table`"} + for _, name := range names { + escaped := EscapeName(name) + test.S(t).ExpectEquals(escaped, "`my_table`") + } +} + +func TestBuildEqualsComparison(t *testing.T) { + { + columns := []string{"c1"} + values := []string{"@v1"} + comparison, err := BuildEqualsComparison(columns, values) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` = @v1))") + } + { + columns := []string{"c1", "c2"} + values := []string{"@v1", "@v2"} + comparison, err := BuildEqualsComparison(columns, values) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` = @v1) and (`c2` = @v2))") + } + { + columns := []string{"c1"} + values := []string{"@v1", "@v2"} + _, err := BuildEqualsComparison(columns, values) + test.S(t).ExpectNotNil(err) + } + { + columns := []string{} + values := []string{} + _, err := BuildEqualsComparison(columns, values) + test.S(t).ExpectNotNil(err) + } +} + +func TestBuildRangeComparison(t *testing.T) { + { + columns := []string{"c1"} + values := []string{"@v1"} + comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` < @v1))") + } + { + columns := []string{"c1"} + values := []string{"@v1"} + comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or ((`c1` = @v1)))") + } + { + columns := []string{"c1", "c2"} + values := []string{"@v1", "@v2"} + comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)))") + } + { + columns := []string{"c1", "c2"} + values := []string{"@v1", "@v2"} + comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or ((`c1` = @v1) and (`c2` = @v2)))") + } + { + columns := []string{"c1", "c2", "c3"} + values := []string{"@v1", "@v2", "@v3"} + comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or (((`c1` = @v1) and (`c2` = @v2)) AND (`c3` < @v3)) or ((`c1` = @v1) and (`c2` = @v2) and (`c3` = @v3)))") + } + { + columns := []string{"c1"} + values := []string{"@v1", "@v2"} + _, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + test.S(t).ExpectNotNil(err) + } + { + columns := []string{} + values := []string{} + _, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign) + test.S(t).ExpectNotNil(err) + } +} + +func TestBuildRangeInsertQuery(t *testing.T) { + databaseName := "mydb" + originalTableName := "tbl" + ghostTableName := "ghost" + sharedColumns := []string{"id", "name", "position"} + { + uniqueKey := "PRIMARY" + uniqueKeyColumns := []string{"id"} + rangeStartValues := []string{"@v1s"} + rangeEndValues := []string{"@v1e"} + + query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues) + test.S(t).ExpectNil(err) + expected := ` + insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) + (select id, name, position from mydb.tbl force index (PRIMARY) + where (((id > @v1s) or ((id = @v1s))) and ((id < @v1e) or ((id = @v1e)))) + ) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + } + { + uniqueKey := "name_position_uidx" + uniqueKeyColumns := []string{"name", "position"} + rangeStartValues := []string{"@v1s", "@v2s"} + rangeEndValues := []string{"@v1e", "@v2e"} + + query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues) + test.S(t).ExpectNil(err) + expected := ` + insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) + (select id, name, position from mydb.tbl force index (name_position_uidx) + where (((name > @v1s) or (((name = @v1s)) AND (position > @v2s)) or ((name = @v1s) and (position = @v2s))) and ((name < @v1e) or (((name = @v1e)) AND (position < @v2e)) or ((name = @v1e) and (position = @v2e)))) + ) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + } +}