gh-ost/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go
Tim Vaillancourt c71dbf9ef3
Copy auto increment (#967)
* v1.1.0

* WIP: copying AUTO_INCREMENT value to ghost table
Initial commit: towards setting up a test suite

Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com>

* grepping for 'expect_table_structure' content

* Adding simple test for 'expect_table_structure' scenario

* adding tests for AUTO_INCREMENT value after row deletes. Should initially fail

* clear event beforehand

* parsing AUTO_INCREMENT from alter query, reading AUTO_INCREMENT from original table, applying AUTO_INCREMENT value onto ghost table if applicable and user has not specified AUTO_INCREMENT in alter statement

* support GetUint64

Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com>

* minor update to test

Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com>

* adding test for user defined AUTO_INCREMENT statement

Co-authored-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com>
2021-05-14 15:32:56 +02:00

347 lines
9.4 KiB
Go

/*
Copyright 2014 Outbrain Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlutils
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/outbrain/golib/log"
"strconv"
"strings"
"sync"
)
// RowMap represents one row in a result set, keyed by column name. Its
// objective is to allow for easy, typed getters by column name.
type RowMap map[string]CellData
// CellData is the result of a single (atomic) column in a single row. It has
// the same layout as sql.NullString: String holds the textual value and Valid
// is false when the column was SQL NULL.
type CellData sql.NullString

// MarshalJSON encodes the cell as a JSON string, or as JSON null when the cell
// is not Valid.
//
// A value receiver is used so encoding/json also applies this method to
// non-addressable CellData values, such as the values stored in a RowMap.
// With a pointer receiver, encoding/json would fall back to the raw struct
// encoding ({"String":...,"Valid":...}) for those values.
func (c CellData) MarshalJSON() ([]byte, error) {
	if !c.Valid {
		return json.Marshal(nil)
	}
	return json.Marshal(c.String)
}

// NullString returns the cell viewed as a *sql.NullString that shares the
// cell's memory, so it can be handed to (*sql.Rows).Scan as a destination.
func (c *CellData) NullString() *sql.NullString {
	return (*sql.NullString)(c)
}
// RowData is the result of a single row, as cells ordered by column position.
type RowData []CellData

// MarshalJSON encodes the row as a JSON array, one element per cell, each
// element produced by CellData's MarshalJSON (a string, or null for a cell
// that is not Valid). Every cell is copied before its address is taken.
func (r *RowData) MarshalJSON() ([]byte, error) {
	row := *r
	cellPtrs := make([]*CellData, len(row))
	for idx := range row {
		cell := row[idx]
		cellPtrs[idx] = &cell
	}
	return json.Marshal(cellPtrs)
}
// ResultData is an ordered row set of RowData.
type ResultData []RowData

// EmptyResultData is an empty, non-nil ResultData; queryResultData returns it
// when the query itself fails.
var EmptyResultData = ResultData{}
// GetString returns the textual value of the named column. A missing column
// yields "". NULL is not distinguished: a NULL cell yields its String field,
// which is normally "".
func (m *RowMap) GetString(key string) string {
	cell := (*m)[key]
	return cell.String
}
// GetStringD returns the textual value of the named column, or def when the
// row has no such column. A column that is present but NULL still returns its
// String field rather than def.
func (m *RowMap) GetStringD(key string, def string) string {
	cell, found := (*m)[key]
	if !found {
		return def
	}
	return cell.String
}
// GetInt64 parses the named column as a base-10 int64. A missing, NULL or
// non-numeric column yields 0. Out-of-range values are clamped to the int64
// bounds, following strconv.ParseInt.
//
// The bit size is 64 (not 0) so the full int64 range is accepted on every
// platform. A bit size of 0 limits parsing to the width of int, which is only
// 32 bits on 32-bit builds.
func (m *RowMap) GetInt64(key string) int64 {
	res, _ := strconv.ParseInt(m.GetString(key), 10, 64)
	return res
}
// GetNullInt64 parses the named column as a base-10 int64. The result is
// Valid only when parsing succeeds. A missing, NULL (empty), non-numeric or
// out-of-range column yields an invalid sql.NullInt64.
//
// The bit size is 64 so the full int64 range parses on 32-bit builds too.
func (m *RowMap) GetNullInt64(key string) sql.NullInt64 {
	i, err := strconv.ParseInt(m.GetString(key), 10, 64)
	if err != nil {
		return sql.NullInt64{Valid: false}
	}
	return sql.NullInt64{Int64: i, Valid: true}
}
// GetInt parses the named column with strconv.Atoi, ignoring any parse error.
// A missing, NULL or non-numeric column yields 0. Out-of-range values are
// clamped to the int bounds, as strconv.Atoi does.
func (m *RowMap) GetInt(key string) int {
	value, _ := strconv.Atoi(m.GetString(key))
	return value
}
// GetIntD parses the named column with strconv.Atoi and returns def when
// parsing fails (missing, NULL, non-numeric or out-of-range column).
func (m *RowMap) GetIntD(key string, def int) int {
	if value, err := strconv.Atoi(m.GetString(key)); err == nil {
		return value
	}
	return def
}
// GetUint parses the named column as a base-10 unsigned integer. A missing,
// NULL, negative or non-numeric column yields 0. Values larger than the uint
// range are clamped to its maximum, following strconv.ParseUint.
//
// strconv.ParseUint is used with bit size 0 (the width of uint). The previous
// approach, strconv.Atoi followed by a uint conversion, wrapped negative
// values around to huge numbers and clamped unsigned values above the int
// maximum (e.g. large BIGINT UNSIGNED values).
func (m *RowMap) GetUint(key string) uint {
	res, _ := strconv.ParseUint(m.GetString(key), 10, 0)
	return uint(res)
}
// GetUintD parses the named column as a base-10 unsigned integer and returns
// def when the column is missing, NULL, negative, non-numeric or larger than
// the uint range.
//
// strconv.ParseUint (bit size 0 = width of uint) rejects negative input
// instead of letting it wrap around through an int-to-uint conversion.
func (m *RowMap) GetUintD(key string, def uint) uint {
	res, err := strconv.ParseUint(m.GetString(key), 10, 0)
	if err != nil {
		return def
	}
	return uint(res)
}
// GetUint64 parses the named column as a base-10 uint64. A missing, NULL,
// negative or non-numeric column yields 0. Values above the uint64 range are
// clamped to its maximum, following strconv.ParseUint.
//
// The bit size is 64 (not 0) so the full uint64 range parses on 32-bit builds
// as well.
func (m *RowMap) GetUint64(key string) uint64 {
	res, _ := strconv.ParseUint(m.GetString(key), 10, 64)
	return res
}
// GetUint64D parses the named column as a base-10 uint64 and returns def when
// the column is missing, NULL, negative, non-numeric or out of range.
//
// The bit size is 64 (not 0) so the full uint64 range parses on 32-bit builds
// as well.
func (m *RowMap) GetUint64D(key string, def uint64) uint64 {
	res, err := strconv.ParseUint(m.GetString(key), 10, 64)
	if err != nil {
		return def
	}
	return res
}
// GetBool reads the named column through GetInt and reports whether it is
// non-zero. A missing, NULL or non-numeric column therefore yields false.
func (m *RowMap) GetBool(key string) bool {
	n := m.GetInt(key)
	return n != 0
}
// knownDBs caches *sql.DB handles by connection URI. knownDBsMutex guards
// access to it from GetDB.
var (
	knownDBs      = make(map[string]*sql.DB)
	knownDBsMutex = new(sync.Mutex)
)
// GetDB returns a *sql.DB for the given MySQL URI. A handle is opened with
// sql.Open on first use and cached by URI for later calls.
//
// The bool result reports whether the handle came from the cache. The error is
// whatever sql.Open returned; a failed open is not cached.
func GetDB(mysql_uri string) (*sql.DB, bool, error) {
	knownDBsMutex.Lock()
	defer knownDBsMutex.Unlock()

	if cached, found := knownDBs[mysql_uri]; found {
		return cached, true, nil
	}
	db, err := sql.Open("mysql", mysql_uri)
	if err != nil {
		return db, false, err
	}
	knownDBs[mysql_uri] = db
	return db, false, nil
}
// RowToArray is a convenience function, typically not called directly. It
// scans the current row of rows into a freshly allocated []CellData with one
// cell per entry of columns. Only len(columns) is used; the names are not.
//
// NOTE(review): the error returned by rows.Scan is discarded. A failed scan
// leaves cells at their zero value (possibly partially filled), and the
// signature offers no way to report it.
func RowToArray(rows *sql.Rows, columns []string) []CellData {
	cells := make([]CellData, len(columns))
	dest := make([]interface{}, 0, len(cells))
	for i := range cells {
		dest = append(dest, cells[i].NullString())
	}
	rows.Scan(dest...)
	return cells
}
// ScanRowsToArrays is a convenience function, typically not called directly.
// It reads every remaining row from rows, converts each into a []CellData via
// RowToArray and passes it to on_row. The first error returned by on_row stops
// the iteration and is returned as-is.
//
// After the loop it returns rows.Err(). rows.Next returning false may mean
// either end-of-results or a failure mid-stream (e.g. a dropped connection);
// only rows.Err tells them apart. Without this check a truncated result set
// is silently treated as complete.
func ScanRowsToArrays(rows *sql.Rows, on_row func([]CellData) error) error {
	columns, err := rows.Columns()
	if err != nil {
		return err
	}
	for rows.Next() {
		if err := on_row(RowToArray(rows, columns)); err != nil {
			return err
		}
	}
	return rows.Err()
}
// rowToMap pairs each cell of row with the column name at the same position.
// columns must have at least len(row) entries. If a column name repeats, the
// later cell wins.
func rowToMap(row []CellData, columns []string) map[string]CellData {
	result := make(map[string]CellData, len(row))
	for i := range row {
		result[columns[i]] = row[i]
	}
	return result
}
// ScanRowsToMaps is a convenience function, typically not called directly. It
// reads every remaining row from rows and passes each to on_row as a RowMap
// keyed by column name. The first error from on_row, or any error reported by
// ScanRowsToArrays, is returned.
func ScanRowsToMaps(rows *sql.Rows, on_row func(RowMap) error) error {
	// NOTE(review): the rows.Columns error is ignored here.
	columns, _ := rows.Columns()
	return ScanRowsToArrays(rows, func(cells []CellData) error {
		return on_row(rowToMap(cells, columns))
	})
}
// QueryRowsMap is a convenience function for querying a result set while
// providing a callback function activated per read row. Rows are streamed:
// on_row is called for each row while the result set is still open.
//
// Query errors are logged via log.Errore and returned. sql.ErrNoRows is
// treated as an empty result. The first error from on_row stops iteration and
// is returned.
//
// A panic raised while querying or inside on_row is recovered and returned as
// an error. The result is named so that the deferred recover can actually
// replace the returned value.
func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) (err error) {
	defer func() {
		if derr := recover(); derr != nil {
			err = errors.New(fmt.Sprintf("QueryRowsMap unexpected error: %+v", derr))
		}
	}()

	rows, queryErr := db.Query(query, args...)
	if queryErr == sql.ErrNoRows {
		return nil
	}
	if queryErr != nil {
		return log.Errore(queryErr)
	}
	// Only defer Close once rows is known to be non-nil.
	defer rows.Close()

	return ScanRowsToMaps(rows, on_row)
}
// queryResultData runs query and buffers the complete result set in memory.
// When retrieveColumns is true the column names are returned as well;
// otherwise the returned column slice is empty.
//
// Query errors are logged via log.Errore and returned together with
// EmptyResultData. sql.ErrNoRows is treated as an empty result.
//
// A panic raised while querying or scanning is recovered and returned as an
// error, with EmptyResultData. The results are named so the deferred recover
// can actually replace them.
func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...interface{}) (resultData ResultData, columns []string, err error) {
	defer func() {
		if derr := recover(); derr != nil {
			resultData = EmptyResultData
			err = errors.New(fmt.Sprintf("queryResultData unexpected error: %+v", derr))
		}
	}()

	columns = []string{}
	rows, queryErr := db.Query(query, args...)
	if queryErr == sql.ErrNoRows {
		return EmptyResultData, columns, nil
	}
	if queryErr != nil {
		return EmptyResultData, columns, log.Errore(queryErr)
	}
	// Only defer Close once rows is known to be non-nil.
	defer rows.Close()

	if retrieveColumns {
		// Don't pay for the column names unless the caller wants them.
		if columns, err = rows.Columns(); err != nil {
			return EmptyResultData, []string{}, err
		}
	}
	resultData = ResultData{}
	err = ScanRowsToArrays(rows, func(rowData []CellData) error {
		resultData = append(resultData, rowData)
		return nil
	})
	return resultData, columns, err
}
// QueryResultData runs query with args and returns the whole result set as a
// raw ResultData, without column names.
func QueryResultData(db *sql.DB, query string, args ...interface{}) (ResultData, error) {
	const withColumns = false
	data, _, err := queryResultData(db, query, withColumns, args...)
	return data, err
}
// QueryResultDataNamed runs query with args and returns the whole result set
// as a raw ResultData together with the result's column names.
func QueryResultDataNamed(db *sql.DB, query string, args ...interface{}) (ResultData, []string, error) {
	const withColumns = true
	data, names, err := queryResultData(db, query, withColumns, args...)
	return data, names, err
}
// QueryRowsMapBuffered reads the whole result set into memory first. It then
// calls on_row once per row, in order, with a RowMap keyed by column name.
// The callback can take its time because the result set has already been read
// and closed, at the cost of holding the entire result set in memory.
// Iteration stops at the first error, from either the query or on_row.
func QueryRowsMapBuffered(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) error {
	resultData, columns, err := queryResultData(db, query, true, args...)
	if err != nil {
		// queryResultData already reports query failures via log.Errore.
		return err
	}
	for i := range resultData {
		if err := on_row(rowToMap(resultData[i], columns)); err != nil {
			return err
		}
	}
	return nil
}
// ExecNoPrepare executes query with args directly through db.Exec, without the
// explicit Prepare/Close cycle that Exec performs. Whether the driver prepares
// internally is up to the driver. A non-nil error is reported via log.Errore
// and returned.
//
// A panic raised during execution is recovered and returned as an error. The
// results are named so that the deferred recover can actually replace them.
func ExecNoPrepare(db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) {
	defer func() {
		if derr := recover(); derr != nil {
			err = errors.New(fmt.Sprintf("ExecNoPrepare unexpected error: %+v", derr))
		}
	}()

	res, err = db.Exec(query, args...)
	if err != nil {
		log.Errore(err)
	}
	return res, err
}
// execInternal prepares query on db, executes it once with args and then
// closes the statement. Execution errors are reported via log.Errore unless
// silent is true. Preparation errors are returned without being reported.
//
// A panic raised while preparing or executing is recovered and returned as an
// error. The results are named so that the deferred recover can actually
// replace them.
func execInternal(silent bool, db *sql.DB, query string, args ...interface{}) (res sql.Result, err error) {
	defer func() {
		if derr := recover(); derr != nil {
			err = errors.New(fmt.Sprintf("execInternal unexpected error: %+v", derr))
		}
	}()

	stmt, err := db.Prepare(query)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	res, err = stmt.Exec(args...)
	if err != nil && !silent {
		log.Errore(err)
	}
	return res, err
}
// Exec executes query with args on db. It safely prepares, executes and closes
// the statement, reporting any execution error via log.Errore before
// returning it.
func Exec(db *sql.DB, query string, args ...interface{}) (sql.Result, error) {
	const silent = false
	return execInternal(silent, db, query, args...)
}
// ExecSilently behaves like Exec but does not report execution errors via the
// log. Errors are still returned to the caller.
func ExecSilently(db *sql.DB, query string, args ...interface{}) (sql.Result, error) {
	const silent = true
	return execInternal(silent, db, query, args...)
}
// sqlStringLiteralEscaper escapes text for embedding in a single-quoted MySQL
// string literal. Backslashes are doubled because MySQL treats backslash as an
// escape character by default; single quotes are doubled per the SQL standard.
var sqlStringLiteralEscaper = strings.NewReplacer(`\`, `\\`, `'`, `''`)

// InClauseStringValues renders terms as a comma-separated list of quoted SQL
// string literals, suitable for an IN (...) clause. For example,
// []string{"a", "o'neil"} becomes `'a', 'o''neil'`. An empty terms slice
// yields "".
//
// Each term is escaped before quoting: single quotes (not commas) must be
// doubled, otherwise a term containing a quote breaks out of its literal.
//
// NOTE: backslash escaping assumes MySQL's default sql_mode. Under
// NO_BACKSLASH_ESCAPES a backslash in a term would end up doubled in the
// compared value. Prefer placeholders with bound arguments where possible.
func InClauseStringValues(terms []string) string {
	quoted := make([]string, 0, len(terms))
	for _, s := range terms {
		quoted = append(quoted, fmt.Sprintf("'%s'", sqlStringLiteralEscaper.Replace(s)))
	}
	return strings.Join(quoted, ", ")
}
// Args returns its variadic arguments as a []interface{}, a convenient way to
// build argument lists for the query and exec helpers. The returned slice is
// the variadic slice itself, which is nil when called with no arguments.
func Args(args ...interface{}) []interface{} {
	collected := args
	return collected
}