364 lines
7.0 KiB
Go
364 lines
7.0 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"math"
|
|
"strconv"
|
|
|
|
"github.com/juju/errors"
|
|
. "github.com/siddontang/go-mysql/mysql"
|
|
)
|
|
|
|
var paramFieldData []byte
|
|
var columnFieldData []byte
|
|
|
|
func init() {
|
|
var p = &Field{Name: []byte("?")}
|
|
var c = &Field{}
|
|
|
|
paramFieldData = p.Dump()
|
|
columnFieldData = c.Dump()
|
|
}
|
|
|
|
// Stmt is the server-side record of one prepared statement on a connection.
type Stmt struct {
	// ID is the statement id reported to the client in the prepare response.
	ID uint32

	// Query is the statement text handed to the handler's HandleStmtExecute.
	Query string

	// Params is the number of statement parameters; ResetParams sizes Args
	// to this length.
	Params int

	// Columns is the number of result columns announced in the prepare
	// response.
	Columns int

	// Args holds the bound parameter values for the next execution, one
	// element per parameter. A nil element means NULL or not yet bound;
	// long data sent for a parameter accumulates here as a []byte.
	Args []interface{}

	// Context is an arbitrary value stored with the statement (see Rest) and
	// passed to the handler's HandleStmtExecute and HandleStmtClose.
	Context interface{}
}
|
|
|
|
func (s *Stmt) Rest(params int, columns int, context interface{}) {
|
|
s.Params = params
|
|
s.Columns = columns
|
|
s.Context = context
|
|
s.ResetParams()
|
|
}
|
|
|
|
func (s *Stmt) ResetParams() {
|
|
s.Args = make([]interface{}, s.Params)
|
|
}
|
|
|
|
func (c *Conn) writePrepare(s *Stmt) error {
|
|
data := make([]byte, 4, 128)
|
|
|
|
//status ok
|
|
data = append(data, 0)
|
|
//stmt id
|
|
data = append(data, Uint32ToBytes(s.ID)...)
|
|
//number columns
|
|
data = append(data, Uint16ToBytes(uint16(s.Columns))...)
|
|
//number params
|
|
data = append(data, Uint16ToBytes(uint16(s.Params))...)
|
|
//filter [00]
|
|
data = append(data, 0)
|
|
//warning count
|
|
data = append(data, 0, 0)
|
|
|
|
if err := c.WritePacket(data); err != nil {
|
|
return err
|
|
}
|
|
|
|
if s.Params > 0 {
|
|
for i := 0; i < s.Params; i++ {
|
|
data = data[0:4]
|
|
data = append(data, []byte(paramFieldData)...)
|
|
|
|
if err := c.WritePacket(data); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
}
|
|
|
|
if err := c.writeEOF(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if s.Columns > 0 {
|
|
for i := 0; i < s.Columns; i++ {
|
|
data = data[0:4]
|
|
data = append(data, []byte(columnFieldData)...)
|
|
|
|
if err := c.WritePacket(data); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
}
|
|
|
|
if err := c.writeEOF(); err != nil {
|
|
return err
|
|
}
|
|
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) handleStmtExecute(data []byte) (*Result, error) {
|
|
if len(data) < 9 {
|
|
return nil, ErrMalformPacket
|
|
}
|
|
|
|
pos := 0
|
|
id := binary.LittleEndian.Uint32(data[0:4])
|
|
pos += 4
|
|
|
|
s, ok := c.stmts[id]
|
|
if !ok {
|
|
return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
|
|
strconv.FormatUint(uint64(id), 10), "stmt_execute")
|
|
}
|
|
|
|
flag := data[pos]
|
|
pos++
|
|
//now we only support CURSOR_TYPE_NO_CURSOR flag
|
|
if flag != 0 {
|
|
return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag))
|
|
}
|
|
|
|
//skip iteration-count, always 1
|
|
pos += 4
|
|
|
|
var nullBitmaps []byte
|
|
var paramTypes []byte
|
|
var paramValues []byte
|
|
|
|
paramNum := s.Params
|
|
|
|
if paramNum > 0 {
|
|
nullBitmapLen := (s.Params + 7) >> 3
|
|
if len(data) < (pos + nullBitmapLen + 1) {
|
|
return nil, ErrMalformPacket
|
|
}
|
|
nullBitmaps = data[pos : pos+nullBitmapLen]
|
|
pos += nullBitmapLen
|
|
|
|
//new param bound flag
|
|
if data[pos] == 1 {
|
|
pos++
|
|
if len(data) < (pos + (paramNum << 1)) {
|
|
return nil, ErrMalformPacket
|
|
}
|
|
|
|
paramTypes = data[pos : pos+(paramNum<<1)]
|
|
pos += paramNum << 1
|
|
|
|
paramValues = data[pos:]
|
|
}
|
|
|
|
if err := c.bindStmtArgs(s, nullBitmaps, paramTypes, paramValues); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
}
|
|
|
|
var r *Result
|
|
var err error
|
|
if r, err = c.h.HandleStmtExecute(s.Context, s.Query, s.Args); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
s.ResetParams()
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error {
|
|
args := s.Args
|
|
|
|
pos := 0
|
|
|
|
var v []byte
|
|
var n int = 0
|
|
var isNull bool
|
|
var err error
|
|
|
|
for i := 0; i < s.Params; i++ {
|
|
if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
|
|
args[i] = nil
|
|
continue
|
|
}
|
|
|
|
tp := paramTypes[i<<1]
|
|
isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
|
|
|
|
switch tp {
|
|
case MYSQL_TYPE_NULL:
|
|
args[i] = nil
|
|
continue
|
|
|
|
case MYSQL_TYPE_TINY:
|
|
if len(paramValues) < (pos + 1) {
|
|
return ErrMalformPacket
|
|
}
|
|
|
|
if isUnsigned {
|
|
args[i] = uint8(paramValues[pos])
|
|
} else {
|
|
args[i] = int8(paramValues[pos])
|
|
}
|
|
|
|
pos++
|
|
continue
|
|
|
|
case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
|
|
if len(paramValues) < (pos + 2) {
|
|
return ErrMalformPacket
|
|
}
|
|
|
|
if isUnsigned {
|
|
args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2]))
|
|
} else {
|
|
args[i] = int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2]))
|
|
}
|
|
pos += 2
|
|
continue
|
|
|
|
case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG:
|
|
if len(paramValues) < (pos + 4) {
|
|
return ErrMalformPacket
|
|
}
|
|
|
|
if isUnsigned {
|
|
args[i] = uint32(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
|
|
} else {
|
|
args[i] = int32(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
|
|
}
|
|
pos += 4
|
|
continue
|
|
|
|
case MYSQL_TYPE_LONGLONG:
|
|
if len(paramValues) < (pos + 8) {
|
|
return ErrMalformPacket
|
|
}
|
|
|
|
if isUnsigned {
|
|
args[i] = binary.LittleEndian.Uint64(paramValues[pos : pos+8])
|
|
} else {
|
|
args[i] = int64(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
|
|
}
|
|
pos += 8
|
|
continue
|
|
|
|
case MYSQL_TYPE_FLOAT:
|
|
if len(paramValues) < (pos + 4) {
|
|
return ErrMalformPacket
|
|
}
|
|
|
|
args[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])))
|
|
pos += 4
|
|
continue
|
|
|
|
case MYSQL_TYPE_DOUBLE:
|
|
if len(paramValues) < (pos + 8) {
|
|
return ErrMalformPacket
|
|
}
|
|
|
|
args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
|
|
pos += 8
|
|
continue
|
|
|
|
case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR,
|
|
MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB,
|
|
MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
|
|
MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY,
|
|
MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE,
|
|
MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME:
|
|
if len(paramValues) < (pos + 1) {
|
|
return ErrMalformPacket
|
|
}
|
|
|
|
v, isNull, n, err = LengthEncodedString(paramValues[pos:])
|
|
pos += n
|
|
if err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
if !isNull {
|
|
args[i] = v
|
|
continue
|
|
} else {
|
|
args[i] = nil
|
|
continue
|
|
}
|
|
default:
|
|
return errors.Errorf("Stmt Unknown FieldType %d", tp)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// stmt send long data command has no response
|
|
func (c *Conn) handleStmtSendLongData(data []byte) error {
|
|
if len(data) < 6 {
|
|
return nil
|
|
}
|
|
|
|
id := binary.LittleEndian.Uint32(data[0:4])
|
|
|
|
s, ok := c.stmts[id]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
paramId := binary.LittleEndian.Uint16(data[4:6])
|
|
if paramId >= uint16(s.Params) {
|
|
return nil
|
|
}
|
|
|
|
if s.Args[paramId] == nil {
|
|
s.Args[paramId] = data[6:]
|
|
} else {
|
|
if b, ok := s.Args[paramId].([]byte); ok {
|
|
b = append(b, data[6:]...)
|
|
s.Args[paramId] = b
|
|
} else {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) handleStmtReset(data []byte) (*Result, error) {
|
|
if len(data) < 4 {
|
|
return nil, ErrMalformPacket
|
|
}
|
|
|
|
id := binary.LittleEndian.Uint32(data[0:4])
|
|
|
|
s, ok := c.stmts[id]
|
|
if !ok {
|
|
return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
|
|
strconv.FormatUint(uint64(id), 10), "stmt_reset")
|
|
}
|
|
|
|
s.ResetParams()
|
|
|
|
return &Result{}, nil
|
|
}
|
|
|
|
// stmt close command has no response
|
|
func (c *Conn) handleStmtClose(data []byte) error {
|
|
if len(data) < 4 {
|
|
return nil
|
|
}
|
|
|
|
id := binary.LittleEndian.Uint32(data[0:4])
|
|
|
|
stmt, ok := c.stmts[id]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if err := c.h.HandleStmtClose(stmt.Context); err != nil {
|
|
return err
|
|
}
|
|
|
|
delete(c.stmts, id)
|
|
|
|
return nil
|
|
}
|