// gh-ost/vendor/github.com/siddontang/go-mysql/server/stmt.go
//
// 364 lines
// 7.0 KiB
// Go

package server
import (
"encoding/binary"
"fmt"
"math"
"strconv"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
)
// Serialized field definitions, built once at package load and reused
// by writePrepare for every parameter / column definition packet.
var (
	paramFieldData  []byte
	columnFieldData []byte
)

// init serializes the two placeholder fields: a field named "?" for
// parameters and a zero-valued field for columns.
func init() {
	paramFieldData = (&Field{Name: []byte("?")}).Dump()
	columnFieldData = (&Field{}).Dump()
}
// Stmt is the per-connection state kept for one prepared statement.
type Stmt struct {
	// ID is the statement id written to the client in the prepare response.
	ID uint32
	// Query is the statement text handed to the handler on execute.
	Query string

	// Params and Columns are the number of parameter and column
	// definitions announced in the prepare response.
	Params, Columns int

	// Args holds the currently bound parameter values, one slot per parameter.
	Args []interface{}

	// Context is an opaque value passed through to the handler's
	// statement callbacks.
	Context interface{}
}

// Rest stores the parameter count, column count and handler context on
// the statement and then reallocates Args to match the new parameter count.
// NOTE(review): the name looks like a typo for "Reset"; it is kept as-is
// because it is part of the exported API.
func (s *Stmt) Rest(params int, columns int, context interface{}) {
	s.Params, s.Columns, s.Context = params, columns, context
	s.ResetParams()
}

// ResetParams discards every bound argument by replacing Args with a
// fresh slice of s.Params nil entries.
func (s *Stmt) ResetParams() {
	fresh := make([]interface{}, s.Params)
	s.Args = fresh
}
// writePrepare sends the response to a prepare command for s: one
// status packet, then s.Params parameter definitions followed by an EOF
// (skipped when there are no parameters), then s.Columns column
// definitions followed by an EOF (skipped when there are no columns).
func (c *Conn) writePrepare(s *Stmt) error {
	// Every packet is built after a 4-byte prefix that is left for
	// WritePacket (presumably the packet header — confirm in WritePacket).
	buf := make([]byte, 4, 128)

	buf = append(buf, 0)                                   // status: OK
	buf = append(buf, Uint32ToBytes(s.ID)...)              // statement id
	buf = append(buf, Uint16ToBytes(uint16(s.Columns))...) // number of columns
	buf = append(buf, Uint16ToBytes(uint16(s.Params))...)  // number of params
	buf = append(buf, 0)                                   // filler [00]
	buf = append(buf, 0, 0)                                // warning count

	if err := c.WritePacket(buf); err != nil {
		return err
	}

	// sendDefs writes count copies of def, one packet each, and closes
	// the run with an EOF. Nothing at all is sent when count is not positive.
	sendDefs := func(count int, def []byte) error {
		if count <= 0 {
			return nil
		}
		for sent := 0; sent < count; sent++ {
			buf = append(buf[:4], def...)
			if err := c.WritePacket(buf); err != nil {
				return errors.Trace(err)
			}
		}
		return c.writeEOF()
	}

	if err := sendDefs(s.Params, paramFieldData); err != nil {
		return err
	}
	return sendDefs(s.Columns, columnFieldData)
}
// handleStmtExecute processes the payload of a statement-execute command.
//
// The payload starts with a little-endian 4-byte statement id, one flag
// byte and a 4-byte iteration count (ignored). When the statement has
// parameters these are followed by a NULL bitmap, a new-params-bound
// byte and, if that byte is 1, two type bytes per parameter and the
// packed parameter values.
//
// It returns ErrMalformPacket for a truncated payload, an
// ER_UNKNOWN_STMT_HANDLER error for an unknown id and an ER_UNKNOWN_ERROR
// error for any non-zero flag. After a successful handler call the
// statement's bound arguments are reset.
func (c *Conn) handleStmtExecute(data []byte) (*Result, error) {
	// statement id (4) + flags (1) + iteration count (4)
	const headerLen = 9
	if len(data) < headerLen {
		return nil, ErrMalformPacket
	}

	stmtID := binary.LittleEndian.Uint32(data[:4])
	stmt, found := c.stmts[stmtID]
	if !found {
		return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
			strconv.FormatUint(uint64(stmtID), 10), "stmt_execute")
	}

	// Only CURSOR_TYPE_NO_CURSOR (0) is supported for now.
	if cursorFlag := data[4]; cursorFlag != 0 {
		return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", cursorFlag))
	}

	if count := stmt.Params; count > 0 {
		// The NULL bitmap holds one bit per parameter, rounded up to whole bytes.
		bitmapStart := headerLen
		bitmapEnd := bitmapStart + (count+7)/8

		// The byte right after the bitmap is the new-params-bound flag.
		if len(data) < bitmapEnd+1 {
			return nil, ErrMalformPacket
		}
		nullBitmap := data[bitmapStart:bitmapEnd]

		var types, values []byte
		if data[bitmapEnd] == 1 {
			typesStart := bitmapEnd + 1
			typesEnd := typesStart + count*2
			if len(data) < typesEnd {
				return nil, ErrMalformPacket
			}
			types = data[typesStart:typesEnd]
			values = data[typesEnd:]
		}

		if err := c.bindStmtArgs(stmt, nullBitmap, types, values); err != nil {
			return nil, errors.Trace(err)
		}
	}

	result, err := c.h.HandleStmtExecute(stmt.Context, stmt.Query, stmt.Args)
	if err != nil {
		return nil, errors.Trace(err)
	}

	stmt.ResetParams()
	return result, nil
}
// bindStmtArgs decodes the binary parameter section of an execute
// packet into s.Args.
//
// nullBitmap carries one bit per parameter; a set bit stores nil for
// that parameter and consumes neither type nor value bytes.
// paramTypes carries two bytes per parameter: the type code followed by
// a flag byte whose 0x80 bit marks the value as unsigned.
// paramValues carries the values of the remaining parameters packed
// back to back in parameter order.
//
// It returns ErrMalformPacket when paramTypes or paramValues is too
// short for a parameter that needs it, the traced decoding error for a
// bad length-encoded value, and an error for an unknown type code.
func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error {
	args := s.Args
	pos := 0

	for i := 0; i < s.Params; i++ {
		if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
			args[i] = nil
			continue
		}

		// handleStmtExecute passes nil paramTypes when the client did not
		// set the new-params-bound flag. Reject the packet instead of
		// panicking on an out-of-range index.
		typeOff := i << 1
		if len(paramTypes) < typeOff+2 {
			return ErrMalformPacket
		}
		tp := paramTypes[typeOff]
		isUnsigned := paramTypes[typeOff+1]&0x80 > 0

		switch tp {
		case MYSQL_TYPE_NULL:
			args[i] = nil

		case MYSQL_TYPE_TINY:
			if len(paramValues) < pos+1 {
				return ErrMalformPacket
			}
			if isUnsigned {
				args[i] = uint8(paramValues[pos])
			} else {
				args[i] = int8(paramValues[pos])
			}
			pos++

		case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
			if len(paramValues) < pos+2 {
				return ErrMalformPacket
			}
			raw := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
			if isUnsigned {
				args[i] = raw
			} else {
				args[i] = int16(raw)
			}
			pos += 2

		case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG:
			if len(paramValues) < pos+4 {
				return ErrMalformPacket
			}
			raw := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
			if isUnsigned {
				args[i] = raw
			} else {
				args[i] = int32(raw)
			}
			pos += 4

		case MYSQL_TYPE_LONGLONG:
			if len(paramValues) < pos+8 {
				return ErrMalformPacket
			}
			raw := binary.LittleEndian.Uint64(paramValues[pos : pos+8])
			if isUnsigned {
				args[i] = raw
			} else {
				args[i] = int64(raw)
			}
			pos += 8

		case MYSQL_TYPE_FLOAT:
			if len(paramValues) < pos+4 {
				return ErrMalformPacket
			}
			args[i] = math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
			pos += 4

		case MYSQL_TYPE_DOUBLE:
			if len(paramValues) < pos+8 {
				return ErrMalformPacket
			}
			args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
			pos += 8

		case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR,
			MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB,
			MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
			MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY,
			MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE,
			MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME:
			// All of these are read as a length-prefixed byte string and
			// stored as the raw []byte.
			if len(paramValues) < pos+1 {
				return ErrMalformPacket
			}
			v, isNull, n, err := LengthEncodedString(paramValues[pos:])
			if err != nil {
				return errors.Trace(err)
			}
			pos += n

			if isNull {
				args[i] = nil
			} else {
				args[i] = v
			}

		default:
			return errors.Errorf("Stmt Unknown FieldType %d", tp)
		}
	}

	return nil
}
// handleStmtSendLongData stores or extends the value of one statement
// parameter from a send-long-data payload: a little-endian 4-byte
// statement id, a 2-byte parameter index and the data chunk.
//
// The command has no response, so every problem (short payload, unknown
// statement, index out of range, parameter already bound to something
// that is not a []byte) is ignored and nil is always returned.
func (c *Conn) handleStmtSendLongData(data []byte) error {
	if len(data) < 6 {
		return nil
	}

	stmt, found := c.stmts[binary.LittleEndian.Uint32(data[:4])]
	if !found {
		return nil
	}

	idx := binary.LittleEndian.Uint16(data[4:6])
	if idx >= uint16(stmt.Params) {
		return nil
	}

	chunk := data[6:]
	switch bound := stmt.Args[idx].(type) {
	case nil:
		// First chunk for this parameter.
		stmt.Args[idx] = chunk
	case []byte:
		// Later chunks are appended to what was already received.
		stmt.Args[idx] = append(bound, chunk...)
	}
	return nil
}
// handleStmtReset clears the bound arguments of the statement whose
// little-endian 4-byte id starts data and returns an empty Result.
// It returns ErrMalformPacket when data is shorter than four bytes and
// an ER_UNKNOWN_STMT_HANDLER error when the id is not registered.
func (c *Conn) handleStmtReset(data []byte) (*Result, error) {
	if len(data) < 4 {
		return nil, ErrMalformPacket
	}

	stmtID := binary.LittleEndian.Uint32(data[:4])
	if stmt, found := c.stmts[stmtID]; found {
		stmt.ResetParams()
		return &Result{}, nil
	}

	return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
		strconv.FormatUint(uint64(stmtID), 10), "stmt_reset")
}
// handleStmtClose drops the statement whose little-endian 4-byte id
// starts data. The command has no response, so a short payload or an
// unknown id is ignored. The handler's HandleStmtClose is called first;
// if it fails, its error is returned and the statement stays registered.
func (c *Conn) handleStmtClose(data []byte) error {
	if len(data) < 4 {
		return nil
	}

	stmtID := binary.LittleEndian.Uint32(data[:4])
	if stmt, found := c.stmts[stmtID]; found {
		if err := c.h.HandleStmtClose(stmt.Context); err != nil {
			return err
		}
		delete(c.stmts, stmtID)
	}
	return nil
}