512 lines
20 KiB
Go
512 lines
20 KiB
Go
/*
   Copyright 2016 GitHub Inc.
   See https://github.com/github/gh-ost/blob/master/LICENSE
*/
|
|
|
|
package sql
|
|
|
|
import (
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// ValueComparisonSign is the textual SQL operator that is rendered between a
// column and a value by the comparison builders in this file.
type ValueComparisonSign string

// Comparison operators accepted by the builders in this file.
//
// Only LessThanComparisonSign is declared with the ValueComparisonSign type.
// The remaining names are untyped string constants: they convert implicitly
// wherever a ValueComparisonSign is required, and can equally be used as
// plain strings.
const (
	LessThanComparisonSign            = ValueComparisonSign("<")
	LessThanOrEqualsComparisonSign    = "<="
	EqualsComparisonSign              = "="
	GreaterThanOrEqualsComparisonSign = ">="
	GreaterThanComparisonSign         = ">"
	NotEqualsComparisonSign           = "!="
)
|
|
|
|
// EscapeName quotes a db/table/column/... name as an identifier by wrapping it
// with backticks.
//
// When name is itself a valid quoted Go literal according to strconv.Unquote
// (for example it already arrives wrapped in backticks or double quotes), the
// surrounding quotes are removed first so the result is not quoted twice.
// Backticks that remain inside the name are doubled, which is how MySQL
// expects the quote character to be written inside a quoted identifier;
// without that, a name containing a backtick would terminate the identifier
// early and yield a malformed statement.
//
// It is not fool proof. The aim is to do the right thing for ordinary names,
// not to solve SQL injection, which should be irrelevant for this tool.
func EscapeName(name string) string {
	if unquoted, err := strconv.Unquote(name); err == nil {
		name = unquoted
	}
	return "`" + strings.Replace(name, "`", "``", -1) + "`"
}
|
|
|
|
func buildColumnsPreparedValues(columns *ColumnList) []string {
|
|
values := make([]string, columns.Len(), columns.Len())
|
|
for i, column := range columns.Columns() {
|
|
var token string
|
|
if column.timezoneConversion != nil {
|
|
token = fmt.Sprintf("convert_tz(?, '%s', '%s')", column.timezoneConversion.ToTimezone, "+00:00")
|
|
} else if column.Type == JSONColumnType {
|
|
token = "convert(? using utf8mb4)"
|
|
} else {
|
|
token = "?"
|
|
}
|
|
values[i] = token
|
|
}
|
|
return values
|
|
}
|
|
|
|
// buildPreparedValues returns a slice holding length copies of the "?"
// placeholder. A length of zero yields an empty, non-nil slice.
func buildPreparedValues(length int) []string {
	placeholders := make([]string, length)
	for idx := range placeholders {
		placeholders[idx] = "?"
	}
	return placeholders
}
|
|
|
|
// duplicateNames returns a fresh slice with the same elements as names, so the
// caller can modify the result without touching the input. The result is never
// nil, even when names is nil or empty.
func duplicateNames(names []string) []string {
	return append(make([]string, 0, len(names)), names...)
}
|
|
|
|
func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) {
|
|
if column == "" {
|
|
return "", fmt.Errorf("Empty column in GetValueComparison")
|
|
}
|
|
if value == "" {
|
|
return "", fmt.Errorf("Empty value in GetValueComparison")
|
|
}
|
|
comparison := fmt.Sprintf("(%s %s %s)", EscapeName(column), string(comparisonSign), value)
|
|
return comparison, err
|
|
}
|
|
|
|
func BuildEqualsComparison(columns []string, values []string) (result string, err error) {
|
|
if len(columns) == 0 {
|
|
return "", fmt.Errorf("Got 0 columns in GetEqualsComparison")
|
|
}
|
|
if len(columns) != len(values) {
|
|
return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
|
|
}
|
|
comparisons := []string{}
|
|
for i, column := range columns {
|
|
value := values[i]
|
|
comparison, err := BuildValueComparison(column, value, EqualsComparisonSign)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
comparisons = append(comparisons, comparison)
|
|
}
|
|
result = strings.Join(comparisons, " and ")
|
|
result = fmt.Sprintf("(%s)", result)
|
|
return result, nil
|
|
}
|
|
|
|
func BuildEqualsPreparedComparison(columns []string) (result string, err error) {
|
|
values := buildPreparedValues(len(columns))
|
|
return BuildEqualsComparison(columns, values)
|
|
}
|
|
|
|
func BuildSetPreparedClause(columns *ColumnList) (result string, err error) {
|
|
if columns.Len() == 0 {
|
|
return "", fmt.Errorf("Got 0 columns in BuildSetPreparedClause")
|
|
}
|
|
setTokens := []string{}
|
|
for _, column := range columns.Columns() {
|
|
var setToken string
|
|
if column.timezoneConversion != nil {
|
|
setToken = fmt.Sprintf("%s=convert_tz(?, '%s', '%s')", EscapeName(column.Name), column.timezoneConversion.ToTimezone, "+00:00")
|
|
} else if column.Type == JSONColumnType {
|
|
setToken = fmt.Sprintf("%s=convert(? using utf8mb4)", EscapeName(column.Name))
|
|
} else {
|
|
setToken = fmt.Sprintf("%s=?", EscapeName(column.Name))
|
|
}
|
|
setTokens = append(setTokens, setToken)
|
|
}
|
|
return strings.Join(setTokens, ", "), nil
|
|
}
|
|
|
|
func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
|
|
if len(columns) == 0 {
|
|
return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison")
|
|
}
|
|
if len(columns) != len(values) {
|
|
return "", explodedArgs, fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
|
|
}
|
|
if len(columns) != len(args) {
|
|
return "", explodedArgs, fmt.Errorf("Got %d columns but %d args in GetEqualsComparison", len(columns), len(args))
|
|
}
|
|
includeEquals := false
|
|
if comparisonSign == LessThanOrEqualsComparisonSign {
|
|
comparisonSign = LessThanComparisonSign
|
|
includeEquals = true
|
|
}
|
|
if comparisonSign == GreaterThanOrEqualsComparisonSign {
|
|
comparisonSign = GreaterThanComparisonSign
|
|
includeEquals = true
|
|
}
|
|
comparisons := []string{}
|
|
|
|
for i, column := range columns {
|
|
//
|
|
value := values[i]
|
|
rangeComparison, err := BuildValueComparison(column, value, comparisonSign)
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
if len(columns[0:i]) > 0 {
|
|
equalitiesComparison, err := BuildEqualsComparison(columns[0:i], values[0:i])
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
comparison := fmt.Sprintf("(%s AND %s)", equalitiesComparison, rangeComparison)
|
|
comparisons = append(comparisons, comparison)
|
|
explodedArgs = append(explodedArgs, args[0:i]...)
|
|
explodedArgs = append(explodedArgs, args[i])
|
|
} else {
|
|
comparisons = append(comparisons, rangeComparison)
|
|
explodedArgs = append(explodedArgs, args[i])
|
|
}
|
|
}
|
|
|
|
if includeEquals {
|
|
comparison, err := BuildEqualsComparison(columns, values)
|
|
if err != nil {
|
|
return "", explodedArgs, nil
|
|
}
|
|
comparisons = append(comparisons, comparison)
|
|
explodedArgs = append(explodedArgs, args...)
|
|
}
|
|
result = strings.Join(comparisons, " or ")
|
|
result = fmt.Sprintf("(%s)", result)
|
|
return result, explodedArgs, nil
|
|
}
|
|
|
|
func BuildRangePreparedComparison(columns *ColumnList, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) {
|
|
values := buildColumnsPreparedValues(columns)
|
|
return BuildRangeComparison(columns.Names(), values, args, comparisonSign)
|
|
}
|
|
|
|
func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) {
|
|
if len(sharedColumns) == 0 {
|
|
return "", explodedArgs, fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery")
|
|
}
|
|
databaseName = EscapeName(databaseName)
|
|
originalTableName = EscapeName(originalTableName)
|
|
ghostTableName = EscapeName(ghostTableName)
|
|
|
|
mappedSharedColumns = duplicateNames(mappedSharedColumns)
|
|
for i := range mappedSharedColumns {
|
|
mappedSharedColumns[i] = EscapeName(mappedSharedColumns[i])
|
|
}
|
|
mappedSharedColumnsListing := strings.Join(mappedSharedColumns, ", ")
|
|
|
|
sharedColumns = duplicateNames(sharedColumns)
|
|
for i := range sharedColumns {
|
|
sharedColumns[i] = EscapeName(sharedColumns[i])
|
|
}
|
|
sharedColumnsListing := strings.Join(sharedColumns, ", ")
|
|
|
|
uniqueKey = EscapeName(uniqueKey)
|
|
var minRangeComparisonSign ValueComparisonSign = GreaterThanComparisonSign
|
|
if includeRangeStartValues {
|
|
minRangeComparisonSign = GreaterThanOrEqualsComparisonSign
|
|
}
|
|
rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeStartValues, rangeStartArgs, minRangeComparisonSign)
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
|
|
rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeEndValues, rangeEndArgs, LessThanOrEqualsComparisonSign)
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
|
|
transactionalClause := ""
|
|
if transactionalTable {
|
|
transactionalClause = "lock in share mode"
|
|
}
|
|
result = fmt.Sprintf(`
|
|
insert /* gh-ost %s.%s */ ignore into %s.%s (%s)
|
|
(select %s from %s.%s force index (%s)
|
|
where (%s and %s) %s
|
|
)
|
|
`, databaseName, originalTableName, databaseName, ghostTableName, mappedSharedColumnsListing,
|
|
sharedColumnsListing, databaseName, originalTableName, uniqueKey,
|
|
rangeStartComparison, rangeEndComparison, transactionalClause)
|
|
return result, explodedArgs, nil
|
|
}
|
|
|
|
func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) {
|
|
rangeStartValues := buildColumnsPreparedValues(uniqueKeyColumns)
|
|
rangeEndValues := buildColumnsPreparedValues(uniqueKeyColumns)
|
|
return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable)
|
|
}
|
|
|
|
func BuildUniqueKeyRangeEndPreparedQueryViaOffset(databaseName, tableName string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) {
|
|
if uniqueKeyColumns.Len() == 0 {
|
|
return "", explodedArgs, fmt.Errorf("Got 0 columns in BuildUniqueKeyRangeEndPreparedQuery")
|
|
}
|
|
databaseName = EscapeName(databaseName)
|
|
tableName = EscapeName(tableName)
|
|
|
|
var startRangeComparisonSign ValueComparisonSign = GreaterThanComparisonSign
|
|
if includeRangeStartValues {
|
|
startRangeComparisonSign = GreaterThanOrEqualsComparisonSign
|
|
}
|
|
rangeStartComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeStartArgs, startRangeComparisonSign)
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
|
|
rangeEndComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeEndArgs, LessThanOrEqualsComparisonSign)
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
|
|
|
|
uniqueKeyColumnNames := duplicateNames(uniqueKeyColumns.Names())
|
|
uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames))
|
|
uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames))
|
|
for i, column := range uniqueKeyColumns.Columns() {
|
|
uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i])
|
|
if column.Type == EnumColumnType {
|
|
uniqueKeyColumnAscending[i] = fmt.Sprintf("concat(%s) asc", uniqueKeyColumnNames[i])
|
|
uniqueKeyColumnDescending[i] = fmt.Sprintf("concat(%s) desc", uniqueKeyColumnNames[i])
|
|
} else {
|
|
uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumnNames[i])
|
|
uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumnNames[i])
|
|
}
|
|
}
|
|
result = fmt.Sprintf(`
|
|
select /* gh-ost %s.%s %s */
|
|
%s
|
|
from
|
|
%s.%s
|
|
where %s and %s
|
|
order by
|
|
%s
|
|
limit 1
|
|
offset %d
|
|
`, databaseName, tableName, hint,
|
|
strings.Join(uniqueKeyColumnNames, ", "),
|
|
databaseName, tableName,
|
|
rangeStartComparison, rangeEndComparison,
|
|
strings.Join(uniqueKeyColumnAscending, ", "),
|
|
(chunkSize - 1),
|
|
)
|
|
return result, explodedArgs, nil
|
|
}
|
|
|
|
func BuildUniqueKeyRangeEndPreparedQueryViaTemptable(databaseName, tableName string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) {
|
|
if uniqueKeyColumns.Len() == 0 {
|
|
return "", explodedArgs, fmt.Errorf("Got 0 columns in BuildUniqueKeyRangeEndPreparedQuery")
|
|
}
|
|
databaseName = EscapeName(databaseName)
|
|
tableName = EscapeName(tableName)
|
|
|
|
var startRangeComparisonSign ValueComparisonSign = GreaterThanComparisonSign
|
|
if includeRangeStartValues {
|
|
startRangeComparisonSign = GreaterThanOrEqualsComparisonSign
|
|
}
|
|
rangeStartComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeStartArgs, startRangeComparisonSign)
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
|
|
rangeEndComparison, rangeExplodedArgs, err := BuildRangePreparedComparison(uniqueKeyColumns, rangeEndArgs, LessThanOrEqualsComparisonSign)
|
|
if err != nil {
|
|
return "", explodedArgs, err
|
|
}
|
|
explodedArgs = append(explodedArgs, rangeExplodedArgs...)
|
|
|
|
uniqueKeyColumnNames := duplicateNames(uniqueKeyColumns.Names())
|
|
uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames))
|
|
uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames))
|
|
for i, column := range uniqueKeyColumns.Columns() {
|
|
uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i])
|
|
if column.Type == EnumColumnType {
|
|
uniqueKeyColumnAscending[i] = fmt.Sprintf("concat(%s) asc", uniqueKeyColumnNames[i])
|
|
uniqueKeyColumnDescending[i] = fmt.Sprintf("concat(%s) desc", uniqueKeyColumnNames[i])
|
|
} else {
|
|
uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumnNames[i])
|
|
uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumnNames[i])
|
|
}
|
|
}
|
|
result = fmt.Sprintf(`
|
|
select /* gh-ost %s.%s %s */ %s
|
|
from (
|
|
select
|
|
%s
|
|
from
|
|
%s.%s
|
|
where %s and %s
|
|
order by
|
|
%s
|
|
limit %d
|
|
) select_osc_chunk
|
|
order by
|
|
%s
|
|
limit 1
|
|
`, databaseName, tableName, hint, strings.Join(uniqueKeyColumnNames, ", "),
|
|
strings.Join(uniqueKeyColumnNames, ", "), databaseName, tableName,
|
|
rangeStartComparison, rangeEndComparison,
|
|
strings.Join(uniqueKeyColumnAscending, ", "), chunkSize,
|
|
strings.Join(uniqueKeyColumnDescending, ", "),
|
|
)
|
|
return result, explodedArgs, nil
|
|
}
|
|
|
|
func BuildUniqueKeyMinValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList) (string, error) {
|
|
return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "asc")
|
|
}
|
|
|
|
func BuildUniqueKeyMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList) (string, error) {
|
|
return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "desc")
|
|
}
|
|
|
|
func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList, order string) (string, error) {
|
|
if uniqueKeyColumns.Len() == 0 {
|
|
return "", fmt.Errorf("Got 0 columns in BuildUniqueKeyMinMaxValuesPreparedQuery")
|
|
}
|
|
databaseName = EscapeName(databaseName)
|
|
tableName = EscapeName(tableName)
|
|
|
|
uniqueKeyColumnNames := duplicateNames(uniqueKeyColumns.Names())
|
|
uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames))
|
|
for i, column := range uniqueKeyColumns.Columns() {
|
|
uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i])
|
|
if column.Type == EnumColumnType {
|
|
uniqueKeyColumnOrder[i] = fmt.Sprintf("concat(%s) %s", uniqueKeyColumnNames[i], order)
|
|
} else {
|
|
uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumnNames[i], order)
|
|
}
|
|
}
|
|
query := fmt.Sprintf(`
|
|
select /* gh-ost %s.%s */ %s
|
|
from
|
|
%s.%s
|
|
order by
|
|
%s
|
|
limit 1
|
|
`, databaseName, tableName, strings.Join(uniqueKeyColumnNames, ", "),
|
|
databaseName, tableName,
|
|
strings.Join(uniqueKeyColumnOrder, ", "),
|
|
)
|
|
return query, nil
|
|
}
|
|
|
|
func BuildDMLDeleteQuery(databaseName, tableName string, tableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) {
|
|
if len(args) != tableColumns.Len() {
|
|
return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery")
|
|
}
|
|
if uniqueKeyColumns.Len() == 0 {
|
|
return result, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLDeleteQuery")
|
|
}
|
|
for _, column := range uniqueKeyColumns.Columns() {
|
|
tableOrdinal := tableColumns.Ordinals[column.Name]
|
|
arg := column.convertArg(args[tableOrdinal])
|
|
uniqueKeyArgs = append(uniqueKeyArgs, arg)
|
|
}
|
|
databaseName = EscapeName(databaseName)
|
|
tableName = EscapeName(tableName)
|
|
equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names())
|
|
if err != nil {
|
|
return result, uniqueKeyArgs, err
|
|
}
|
|
result = fmt.Sprintf(`
|
|
delete /* gh-ost %s.%s */
|
|
from
|
|
%s.%s
|
|
where
|
|
%s
|
|
`, databaseName, tableName,
|
|
databaseName, tableName,
|
|
equalsComparison,
|
|
)
|
|
return result, uniqueKeyArgs, nil
|
|
}
|
|
|
|
func BuildDMLInsertQuery(databaseName, tableName string, tableColumns, sharedColumns, mappedSharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) {
|
|
if len(args) != tableColumns.Len() {
|
|
return result, args, fmt.Errorf("args count differs from table column count in BuildDMLInsertQuery")
|
|
}
|
|
if !sharedColumns.IsSubsetOf(tableColumns) {
|
|
return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery")
|
|
}
|
|
if sharedColumns.Len() == 0 {
|
|
return result, args, fmt.Errorf("No shared columns found in BuildDMLInsertQuery")
|
|
}
|
|
databaseName = EscapeName(databaseName)
|
|
tableName = EscapeName(tableName)
|
|
|
|
for _, column := range sharedColumns.Columns() {
|
|
tableOrdinal := tableColumns.Ordinals[column.Name]
|
|
arg := column.convertArg(args[tableOrdinal])
|
|
sharedArgs = append(sharedArgs, arg)
|
|
}
|
|
|
|
mappedSharedColumnNames := duplicateNames(mappedSharedColumns.Names())
|
|
for i := range mappedSharedColumnNames {
|
|
mappedSharedColumnNames[i] = EscapeName(mappedSharedColumnNames[i])
|
|
}
|
|
preparedValues := buildColumnsPreparedValues(mappedSharedColumns)
|
|
|
|
result = fmt.Sprintf(`
|
|
replace /* gh-ost %s.%s */ into
|
|
%s.%s
|
|
(%s)
|
|
values
|
|
(%s)
|
|
`, databaseName, tableName,
|
|
databaseName, tableName,
|
|
strings.Join(mappedSharedColumnNames, ", "),
|
|
strings.Join(preparedValues, ", "),
|
|
)
|
|
return result, sharedArgs, nil
|
|
}
|
|
|
|
func BuildDMLUpdateQuery(databaseName, tableName string, tableColumns, sharedColumns, mappedSharedColumns, uniqueKeyColumns *ColumnList, valueArgs, whereArgs []interface{}) (result string, sharedArgs, uniqueKeyArgs []interface{}, err error) {
|
|
if len(valueArgs) != tableColumns.Len() {
|
|
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("value args count differs from table column count in BuildDMLUpdateQuery")
|
|
}
|
|
if len(whereArgs) != tableColumns.Len() {
|
|
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("where args count differs from table column count in BuildDMLUpdateQuery")
|
|
}
|
|
if !sharedColumns.IsSubsetOf(tableColumns) {
|
|
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLUpdateQuery")
|
|
}
|
|
if !uniqueKeyColumns.IsSubsetOf(sharedColumns) {
|
|
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("unique key columns is not a subset of shared columns in BuildDMLUpdateQuery")
|
|
}
|
|
if sharedColumns.Len() == 0 {
|
|
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No shared columns found in BuildDMLUpdateQuery")
|
|
}
|
|
if uniqueKeyColumns.Len() == 0 {
|
|
return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLUpdateQuery")
|
|
}
|
|
databaseName = EscapeName(databaseName)
|
|
tableName = EscapeName(tableName)
|
|
|
|
for _, column := range sharedColumns.Columns() {
|
|
tableOrdinal := tableColumns.Ordinals[column.Name]
|
|
arg := column.convertArg(valueArgs[tableOrdinal])
|
|
sharedArgs = append(sharedArgs, arg)
|
|
}
|
|
|
|
for _, column := range uniqueKeyColumns.Columns() {
|
|
tableOrdinal := tableColumns.Ordinals[column.Name]
|
|
arg := column.convertArg(whereArgs[tableOrdinal])
|
|
uniqueKeyArgs = append(uniqueKeyArgs, arg)
|
|
}
|
|
|
|
setClause, err := BuildSetPreparedClause(mappedSharedColumns)
|
|
|
|
equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names())
|
|
result = fmt.Sprintf(`
|
|
update /* gh-ost %s.%s */
|
|
%s.%s
|
|
set
|
|
%s
|
|
where
|
|
%s
|
|
`, databaseName, tableName,
|
|
databaseName, tableName,
|
|
setClause,
|
|
equalsComparison,
|
|
)
|
|
return result, sharedArgs, uniqueKeyArgs, nil
|
|
}
|