Merge pull request #8 from github/sql-queries-manipulations
merging so I can use this on other branches
This commit is contained in:
commit
9b368a7720
14
go/base/context.go
Normal file
14
go/base/context.go
Normal file
@ -0,0 +1,14 @@
|
||||
/*
|
||||
Copyright 2016 GitHub Inc.
|
||||
See https://github.com/github/gh-osc/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
package base
|
||||
|
||||
import ()
|
||||
|
||||
type MigrationContext struct {
|
||||
DatabaseName string
|
||||
OriginalTableName string
|
||||
GhostTableName string
|
||||
}
|
138
go/sql/builder.go
Normal file
138
go/sql/builder.go
Normal file
@ -0,0 +1,138 @@
|
||||
/*
|
||||
Copyright 2016 GitHub Inc.
|
||||
See https://github.com/github/gh-osc/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ValueComparisonSign string
|
||||
|
||||
const (
|
||||
LessThanComparisonSign ValueComparisonSign = "<"
|
||||
LessThanOrEqualsComparisonSign = "<="
|
||||
EqualsComparisonSign = "="
|
||||
GreaterThanOrEqualsComparisonSign = ">="
|
||||
GreaterThanComparisonSign = ">"
|
||||
NotEqualsComparisonSign = "!="
|
||||
)
|
||||
|
||||
// EscapeName will escape a db/table/column/... name by wrapping with backticks.
|
||||
// It is not fool proof. I'm just trying to do the right thing here, not solving
|
||||
// SQL injection issues, which should be irrelevant for this tool.
|
||||
func EscapeName(name string) string {
|
||||
if unquoted, err := strconv.Unquote(name); err == nil {
|
||||
name = unquoted
|
||||
}
|
||||
return fmt.Sprintf("`%s`", name)
|
||||
}
|
||||
|
||||
func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) {
|
||||
if column == "" {
|
||||
return "", fmt.Errorf("Empty column in GetValueComparison")
|
||||
}
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("Empty value in GetValueComparison")
|
||||
}
|
||||
comparison := fmt.Sprintf("(%s %s %s)", EscapeName(column), string(comparisonSign), value)
|
||||
return comparison, err
|
||||
}
|
||||
|
||||
func BuildEqualsComparison(columns []string, values []string) (result string, err error) {
|
||||
if len(columns) == 0 {
|
||||
return "", fmt.Errorf("Got 0 columns in GetEqualsComparison")
|
||||
}
|
||||
if len(columns) != len(values) {
|
||||
return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
|
||||
}
|
||||
comparisons := []string{}
|
||||
for i, column := range columns {
|
||||
value := values[i]
|
||||
comparison, err := BuildValueComparison(column, value, EqualsComparisonSign)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
comparisons = append(comparisons, comparison)
|
||||
}
|
||||
result = strings.Join(comparisons, " and ")
|
||||
result = fmt.Sprintf("(%s)", result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func BuildRangeComparison(columns []string, values []string, comparisonSign ValueComparisonSign) (result string, err error) {
|
||||
if len(columns) == 0 {
|
||||
return "", fmt.Errorf("Got 0 columns in GetRangeComparison")
|
||||
}
|
||||
if len(columns) != len(values) {
|
||||
return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
|
||||
}
|
||||
includeEquals := false
|
||||
if comparisonSign == LessThanOrEqualsComparisonSign {
|
||||
comparisonSign = LessThanComparisonSign
|
||||
includeEquals = true
|
||||
}
|
||||
if comparisonSign == GreaterThanOrEqualsComparisonSign {
|
||||
comparisonSign = GreaterThanComparisonSign
|
||||
includeEquals = true
|
||||
}
|
||||
comparisons := []string{}
|
||||
|
||||
for i, column := range columns {
|
||||
//
|
||||
value := values[i]
|
||||
rangeComparison, err := BuildValueComparison(column, value, comparisonSign)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(columns[0:i]) > 0 {
|
||||
equalitiesComparison, err := BuildEqualsComparison(columns[0:i], values[0:i])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
comparison := fmt.Sprintf("(%s AND %s)", equalitiesComparison, rangeComparison)
|
||||
comparisons = append(comparisons, comparison)
|
||||
} else {
|
||||
comparisons = append(comparisons, rangeComparison)
|
||||
}
|
||||
}
|
||||
|
||||
if includeEquals {
|
||||
comparison, err := BuildEqualsComparison(columns, values)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
comparisons = append(comparisons, comparison)
|
||||
}
|
||||
result = strings.Join(comparisons, " or ")
|
||||
result = fmt.Sprintf("(%s)", result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string) (string, error) {
|
||||
if len(sharedColumns) == 0 {
|
||||
return "", fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery")
|
||||
}
|
||||
sharedColumnsListing := strings.Join(sharedColumns, ", ")
|
||||
rangeStartComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, GreaterThanOrEqualsComparisonSign)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rangeEndComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, LessThanOrEqualsComparisonSign)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := fmt.Sprintf(`
|
||||
insert /* gh-osc %s.%s */ ignore into %s.%s (%s)
|
||||
(select %s from %s.%s force index (%s)
|
||||
where (%s and %s)
|
||||
)
|
||||
`, databaseName, originalTableName, databaseName, ghostTableName, sharedColumnsListing,
|
||||
sharedColumnsListing, databaseName, originalTableName, uniqueKey,
|
||||
rangeStartComparison, rangeEndComparison)
|
||||
return query, nil
|
||||
}
|
157
go/sql/builder_test.go
Normal file
157
go/sql/builder_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
/*
|
||||
Copyright 2016 GitHub Inc.
|
||||
See https://github.com/github/gh-osc/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/outbrain/golib/log"
|
||||
test "github.com/outbrain/golib/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
spacesRegexp = regexp.MustCompile(`[ \t\n\r]+`)
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.ERROR)
|
||||
}
|
||||
|
||||
func normalizeQuery(name string) string {
|
||||
name = strings.Replace(name, "`", "", -1)
|
||||
name = spacesRegexp.ReplaceAllString(name, " ")
|
||||
name = strings.TrimSpace(name)
|
||||
return name
|
||||
}
|
||||
|
||||
func TestEscapeName(t *testing.T) {
|
||||
names := []string{"my_table", `"my_table"`, "`my_table`"}
|
||||
for _, name := range names {
|
||||
escaped := EscapeName(name)
|
||||
test.S(t).ExpectEquals(escaped, "`my_table`")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEqualsComparison(t *testing.T) {
|
||||
{
|
||||
columns := []string{"c1"}
|
||||
values := []string{"@v1"}
|
||||
comparison, err := BuildEqualsComparison(columns, values)
|
||||
test.S(t).ExpectNil(err)
|
||||
test.S(t).ExpectEquals(comparison, "((`c1` = @v1))")
|
||||
}
|
||||
{
|
||||
columns := []string{"c1", "c2"}
|
||||
values := []string{"@v1", "@v2"}
|
||||
comparison, err := BuildEqualsComparison(columns, values)
|
||||
test.S(t).ExpectNil(err)
|
||||
test.S(t).ExpectEquals(comparison, "((`c1` = @v1) and (`c2` = @v2))")
|
||||
}
|
||||
{
|
||||
columns := []string{"c1"}
|
||||
values := []string{"@v1", "@v2"}
|
||||
_, err := BuildEqualsComparison(columns, values)
|
||||
test.S(t).ExpectNotNil(err)
|
||||
}
|
||||
{
|
||||
columns := []string{}
|
||||
values := []string{}
|
||||
_, err := BuildEqualsComparison(columns, values)
|
||||
test.S(t).ExpectNotNil(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRangeComparison(t *testing.T) {
|
||||
{
|
||||
columns := []string{"c1"}
|
||||
values := []string{"@v1"}
|
||||
comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign)
|
||||
test.S(t).ExpectNil(err)
|
||||
test.S(t).ExpectEquals(comparison, "((`c1` < @v1))")
|
||||
}
|
||||
{
|
||||
columns := []string{"c1"}
|
||||
values := []string{"@v1"}
|
||||
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
|
||||
test.S(t).ExpectNil(err)
|
||||
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or ((`c1` = @v1)))")
|
||||
}
|
||||
{
|
||||
columns := []string{"c1", "c2"}
|
||||
values := []string{"@v1", "@v2"}
|
||||
comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign)
|
||||
test.S(t).ExpectNil(err)
|
||||
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)))")
|
||||
}
|
||||
{
|
||||
columns := []string{"c1", "c2"}
|
||||
values := []string{"@v1", "@v2"}
|
||||
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
|
||||
test.S(t).ExpectNil(err)
|
||||
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or ((`c1` = @v1) and (`c2` = @v2)))")
|
||||
}
|
||||
{
|
||||
columns := []string{"c1", "c2", "c3"}
|
||||
values := []string{"@v1", "@v2", "@v3"}
|
||||
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
|
||||
test.S(t).ExpectNil(err)
|
||||
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or (((`c1` = @v1) and (`c2` = @v2)) AND (`c3` < @v3)) or ((`c1` = @v1) and (`c2` = @v2) and (`c3` = @v3)))")
|
||||
}
|
||||
{
|
||||
columns := []string{"c1"}
|
||||
values := []string{"@v1", "@v2"}
|
||||
_, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
|
||||
test.S(t).ExpectNotNil(err)
|
||||
}
|
||||
{
|
||||
columns := []string{}
|
||||
values := []string{}
|
||||
_, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
|
||||
test.S(t).ExpectNotNil(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRangeInsertQuery(t *testing.T) {
|
||||
databaseName := "mydb"
|
||||
originalTableName := "tbl"
|
||||
ghostTableName := "ghost"
|
||||
sharedColumns := []string{"id", "name", "position"}
|
||||
{
|
||||
uniqueKey := "PRIMARY"
|
||||
uniqueKeyColumns := []string{"id"}
|
||||
rangeStartValues := []string{"@v1s"}
|
||||
rangeEndValues := []string{"@v1e"}
|
||||
|
||||
query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues)
|
||||
test.S(t).ExpectNil(err)
|
||||
expected := `
|
||||
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
|
||||
(select id, name, position from mydb.tbl force index (PRIMARY)
|
||||
where (((id > @v1s) or ((id = @v1s))) and ((id < @v1e) or ((id = @v1e))))
|
||||
)
|
||||
`
|
||||
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
|
||||
}
|
||||
{
|
||||
uniqueKey := "name_position_uidx"
|
||||
uniqueKeyColumns := []string{"name", "position"}
|
||||
rangeStartValues := []string{"@v1s", "@v2s"}
|
||||
rangeEndValues := []string{"@v1e", "@v2e"}
|
||||
|
||||
query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues)
|
||||
test.S(t).ExpectNil(err)
|
||||
expected := `
|
||||
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
|
||||
(select id, name, position from mydb.tbl force index (name_position_uidx)
|
||||
where (((name > @v1s) or (((name = @v1s)) AND (position > @v2s)) or ((name = @v1s) and (position = @v2s))) and ((name < @v1e) or (((name = @v1e)) AND (position < @v2e)) or ((name = @v1e) and (position = @v2e))))
|
||||
)
|
||||
`
|
||||
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user