2021-07-15 21:49:50 +02:00

215 lines
4.7 KiB
Go

package client
import (
"encoding/binary"
"fmt"
"math"
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/errors"
)
// Stmt is a server-side prepared statement, as returned by Conn.Prepare.
type Stmt struct {
	conn *Conn  // connection the statement was prepared on
	id   uint32 // statement id read from the COM_STMT_PREPARE response

	// Parameter and result-column counts read from the COM_STMT_PREPARE response.
	params, columns int
}
// ParamNum returns how many placeholder parameters the statement expects.
// Execute must be called with exactly this many arguments.
func (s *Stmt) ParamNum() int { return s.params }
// ColumnNum returns the number of result columns the server reported when
// the statement was prepared.
func (s *Stmt) ColumnNum() int { return s.columns }
// Execute binds args to the statement's placeholders, sends a
// COM_STMT_EXECUTE packet and returns the server's response.
//
// len(args) must equal ParamNum(); a nil argument is sent as SQL NULL.
func (s *Stmt) Execute(args ...interface{}) (*Result, error) {
	err := s.write(args...)
	if err != nil {
		return nil, errors.Trace(err)
	}
	// NOTE(review): the `true` argument presumably selects the binary
	// result-set protocol used by prepared statements — confirm in readResult.
	return s.conn.readResult(true)
}
// Close sends COM_STMT_CLOSE for this statement's id so the server can
// release it. Only the command is written; no response is read.
func (s *Stmt) Close() error {
	err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id)
	if err == nil {
		return nil
	}
	return errors.Trace(err)
}
// write encodes args into a COM_STMT_EXECUTE packet for this statement and
// sends it on the underlying connection.
//
// len(args) must equal the parameter count reported at prepare time. A nil
// argument is sent as SQL NULL through the NULL bitmap. Non-nil arguments
// must be one of the types handled by encodeStmtParam; otherwise an error is
// returned and nothing is written.
func (s *Stmt) write(args ...interface{}) error {
	paramsNum := s.params
	if len(args) != paramsNum {
		return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
	}

	// Two bytes per parameter: the MYSQL_TYPE_* code, then a flag byte.
	paramTypes := make([]byte, paramsNum<<1)
	paramValues := make([][]byte, paramsNum)
	// NULL-bitmap, length: (num-params+7)/8
	nullBitmap := make([]byte, (paramsNum+7)>>3)

	// Types and values are only sent when at least one argument is non-nil;
	// an all-NULL call sends just the bitmap with the flag left at 0.
	var newParamBoundFlag byte
	valuesLen := 0
	for i, arg := range args {
		if arg == nil {
			nullBitmap[i/8] |= 1 << (uint(i) % 8)
			paramTypes[i<<1] = MYSQL_TYPE_NULL
			continue
		}
		newParamBoundFlag = 1

		fieldType, flags, value, err := encodeStmtParam(arg)
		if err != nil {
			return err
		}
		paramTypes[i<<1] = fieldType
		paramTypes[(i<<1)+1] = flags
		paramValues[i] = value
		valuesLen += len(value)
	}

	// command(1) + statement-id(4) + flags(1) + iteration-count(4)
	length := 1 + 4 + 1 + 4
	if paramsNum > 0 {
		// NULL-bitmap + new-params-bound-flag(1) + types + values
		length += len(nullBitmap) + 1 + len(paramTypes) + valuesLen
	}

	// The leading 4 bytes are left zeroed; presumably WritePacket fills in the
	// packet header there (a MySQL packet header is 4 bytes) — confirm in WritePacket.
	data := make([]byte, 4, 4+length)
	data = append(data, COM_STMT_EXECUTE)
	// statement id, little-endian
	data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24))
	// flag: CURSOR_TYPE_NO_CURSOR
	data = append(data, 0x00)
	// iteration-count, always 1
	data = append(data, 1, 0, 0, 0)

	if paramsNum > 0 {
		data = append(data, nullBitmap...)
		// new-params-bound-flag
		data = append(data, newParamBoundFlag)
		if newParamBoundFlag == 1 {
			// type of each parameter, length: num-params * 2
			data = append(data, paramTypes...)
			// value of each parameter (NULL parameters contribute no bytes)
			for _, v := range paramValues {
				data = append(data, v...)
			}
		}
	}

	s.conn.ResetSequence()
	return s.conn.WritePacket(data)
}

// encodeStmtParam encodes one non-nil COM_STMT_EXECUTE parameter.
//
// It returns the MYSQL_TYPE_* code, the type flag byte (0x80 marks an
// unsigned integer, 0 otherwise) and the little-endian binary-protocol value.
// Strings and byte slices are length-encoded. Unsupported types yield an error.
func encodeStmtParam(arg interface{}) (fieldType byte, flags byte, value []byte, err error) {
	const unsignedFlag byte = 0x80

	switch v := arg.(type) {
	case int8:
		return MYSQL_TYPE_TINY, 0, []byte{byte(v)}, nil
	case int16:
		return MYSQL_TYPE_SHORT, 0, Uint16ToBytes(uint16(v)), nil
	case int32:
		return MYSQL_TYPE_LONG, 0, Uint32ToBytes(uint32(v)), nil
	case int:
		return MYSQL_TYPE_LONGLONG, 0, Uint64ToBytes(uint64(v)), nil
	case int64:
		return MYSQL_TYPE_LONGLONG, 0, Uint64ToBytes(uint64(v)), nil
	case uint8:
		return MYSQL_TYPE_TINY, unsignedFlag, []byte{v}, nil
	case uint16:
		return MYSQL_TYPE_SHORT, unsignedFlag, Uint16ToBytes(v), nil
	case uint32:
		return MYSQL_TYPE_LONG, unsignedFlag, Uint32ToBytes(v), nil
	case uint:
		return MYSQL_TYPE_LONGLONG, unsignedFlag, Uint64ToBytes(uint64(v)), nil
	case uint64:
		return MYSQL_TYPE_LONGLONG, unsignedFlag, Uint64ToBytes(v), nil
	case bool:
		if v {
			return MYSQL_TYPE_TINY, 0, []byte{1}, nil
		}
		return MYSQL_TYPE_TINY, 0, []byte{0}, nil
	case float32:
		return MYSQL_TYPE_FLOAT, 0, Uint32ToBytes(math.Float32bits(v)), nil
	case float64:
		return MYSQL_TYPE_DOUBLE, 0, Uint64ToBytes(math.Float64bits(v)), nil
	case string:
		return MYSQL_TYPE_STRING, 0, append(PutLengthEncodedInt(uint64(len(v))), v...), nil
	case []byte:
		return MYSQL_TYPE_STRING, 0, append(PutLengthEncodedInt(uint64(len(v))), v...), nil
	default:
		return 0, 0, nil, fmt.Errorf("invalid argument type %T", arg)
	}
}
// Prepare sends COM_STMT_PREPARE for query and returns a handle to the
// server-side prepared statement.
//
// It returns the server's error if the server answers with an ERR packet.
// It returns ErrMalformPacket if the response is neither OK nor ERR, or is
// too short to hold the statement id and the column and parameter counts.
func (c *Conn) Prepare(query string) (*Stmt, error) {
	if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil {
		return nil, errors.Trace(err)
	}

	data, err := c.ReadPacket()
	if err != nil {
		return nil, errors.Trace(err)
	}

	// Guard the header read: an empty packet would otherwise panic on data[0].
	if len(data) == 0 {
		return nil, ErrMalformPacket
	}
	switch data[0] {
	case ERR_HEADER:
		return nil, c.handleErrorPacket(data)
	case OK_HEADER:
		// handled below
	default:
		return nil, ErrMalformPacket
	}

	// Layout read here: status(1) statement-id(4) num-columns(2) num-params(2).
	// A truncated packet must surface as an error, not an index-out-of-range panic.
	if len(data) < 9 {
		return nil, ErrMalformPacket
	}

	s := &Stmt{
		conn:    c,
		id:      binary.LittleEndian.Uint32(data[1:5]),
		columns: int(binary.LittleEndian.Uint16(data[5:7])),
		params:  int(binary.LittleEndian.Uint16(data[7:9])),
	}
	// Any bytes after offset 9 (e.g. the warning count) are ignored.

	// The parameter definitions and then the column definitions follow the OK
	// packet when their counts are non-zero. They are drained via
	// readUntilEOF and discarded.
	if s.params > 0 {
		if err := s.conn.readUntilEOF(); err != nil {
			return nil, errors.Trace(err)
		}
	}
	if s.columns > 0 {
		if err := s.conn.readUntilEOF(); err != nil {
			return nil, errors.Trace(err)
		}
	}

	return s, nil
}