2016-06-16 11:15:56 +02:00

401 lines
7.8 KiB
Go

package mysql
import (
"strconv"
"github.com/juju/errors"
"github.com/siddontang/go/hack"
)
// RowData is the raw byte payload of one row. Its Parse, ParseText and
// ParseBinary methods decode it into one value per field.
type RowData []byte
// Parse decodes the row into one value per entry of f. When binary is true
// the row is handed to ParseBinary, otherwise to ParseText; the result of the
// chosen decoder is returned unchanged.
func (p RowData) Parse(f []*Field, binary bool) ([]interface{}, error) {
	if !binary {
		return p.ParseText(f)
	}
	return p.ParseBinary(f)
}
// ParseText decodes a row made of one length-encoded string per field in f,
// read back to back from p.
//
// A NULL column yields nil. Columns typed TINY, SHORT, INT24, LONGLONG or
// YEAR are parsed as base-10 integers: uint64 when the field carries
// UNSIGNED_FLAG, int64 otherwise. FLOAT and DOUBLE columns are parsed as
// float64. Every other type is returned as the raw []byte. Any decoding or
// numeric parsing error aborts the row and is returned traced.
//
// NOTE(review): MYSQL_TYPE_LONG is not in the integer list, so those columns
// come back as []byte — confirm whether that is intentional.
func (p RowData) ParseText(f []*Field) ([]interface{}, error) {
	values := make([]interface{}, len(f))
	offset := 0

	for i, field := range f {
		raw, null, size, err := LengthEnodedString(p[offset:])
		if err != nil {
			return nil, errors.Trace(err)
		}
		offset += size

		if null {
			// values[i] is already the nil interface.
			continue
		}

		switch field.Type {
		case MYSQL_TYPE_TINY, MYSQL_TYPE_SHORT, MYSQL_TYPE_INT24,
			MYSQL_TYPE_LONGLONG, MYSQL_TYPE_YEAR:
			if field.Flag&UNSIGNED_FLAG == 0 {
				values[i], err = strconv.ParseInt(string(raw), 10, 64)
			} else {
				values[i], err = strconv.ParseUint(string(raw), 10, 64)
			}
		case MYSQL_TYPE_FLOAT, MYSQL_TYPE_DOUBLE:
			values[i], err = strconv.ParseFloat(string(raw), 64)
		default:
			values[i] = raw
		}
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	return values, nil
}
// ParseBinary decodes a binary-protocol row into one value per field in f.
//
// The row must start with OK_HEADER, followed by a NULL bitmap whose first
// two bits are reserved (hence the +2 offsets below), followed by the values
// of the non-NULL columns in field order. Integer and floating-point columns
// are decoded with the ParseBinary* helpers, string-like columns are returned
// as []byte, and date/time columns are rendered by the FormatBinary* helpers.
//
// It returns ErrMalformPacket when the header is wrong or the row is too
// short for the data the field list calls for, and an error for a field type
// it does not know how to decode.
func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) {
	data := make([]interface{}, len(f))

	if len(p) < 1 || p[0] != OK_HEADER {
		return nil, ErrMalformPacket
	}

	pos := 1 + ((len(f) + 7 + 2) >> 3)
	if len(p) < pos {
		// Row ends before the NULL bitmap does.
		return nil, ErrMalformPacket
	}
	nullBitmap := p[1:pos]

	var isNull bool
	var n int
	var err error
	var v []byte

	for i := range data {
		if nullBitmap[(i+2)/8]&(1<<(uint(i+2)%8)) > 0 {
			data[i] = nil
			continue
		}

		// The variable-length cases slice p[pos:], which panics once pos has
		// run past the end of the row.
		if pos > len(p) {
			return nil, ErrMalformPacket
		}

		isUnsigned := f[i].Flag&UNSIGNED_FLAG != 0

		switch f[i].Type {
		case MYSQL_TYPE_NULL:
			data[i] = nil
			continue

		case MYSQL_TYPE_TINY:
			if len(p) < pos+1 {
				return nil, ErrMalformPacket
			}
			if isUnsigned {
				data[i] = ParseBinaryUint8(p[pos : pos+1])
			} else {
				data[i] = ParseBinaryInt8(p[pos : pos+1])
			}
			pos++
			continue

		case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
			if len(p) < pos+2 {
				return nil, ErrMalformPacket
			}
			if isUnsigned {
				data[i] = ParseBinaryUint16(p[pos : pos+2])
			} else {
				data[i] = ParseBinaryInt16(p[pos : pos+2])
			}
			pos += 2
			continue

		case MYSQL_TYPE_INT24:
			// The binary protocol sends INT24 in four bytes, like LONG; only
			// the low three are decoded, but all four must be skipped.
			if len(p) < pos+4 {
				return nil, ErrMalformPacket
			}
			if isUnsigned {
				data[i] = ParseBinaryUint24(p[pos : pos+3])
			} else {
				data[i] = ParseBinaryInt24(p[pos : pos+3])
			}
			pos += 4
			continue

		case MYSQL_TYPE_LONG:
			if len(p) < pos+4 {
				return nil, ErrMalformPacket
			}
			if isUnsigned {
				data[i] = ParseBinaryUint32(p[pos : pos+4])
			} else {
				data[i] = ParseBinaryInt32(p[pos : pos+4])
			}
			pos += 4
			continue

		case MYSQL_TYPE_LONGLONG:
			if len(p) < pos+8 {
				return nil, ErrMalformPacket
			}
			if isUnsigned {
				data[i] = ParseBinaryUint64(p[pos : pos+8])
			} else {
				data[i] = ParseBinaryInt64(p[pos : pos+8])
			}
			pos += 8
			continue

		case MYSQL_TYPE_FLOAT:
			if len(p) < pos+4 {
				return nil, ErrMalformPacket
			}
			data[i] = ParseBinaryFloat32(p[pos : pos+4])
			pos += 4
			continue

		case MYSQL_TYPE_DOUBLE:
			if len(p) < pos+8 {
				return nil, ErrMalformPacket
			}
			// A double is eight bytes wide; the slice must cover all of them.
			data[i] = ParseBinaryFloat64(p[pos : pos+8])
			pos += 8
			continue

		case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR,
			MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB,
			MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
			MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY:
			v, isNull, n, err = LengthEnodedString(p[pos:])
			pos += n
			if err != nil {
				return nil, errors.Trace(err)
			}
			if isNull {
				data[i] = nil
			} else {
				data[i] = v
			}
			continue

		case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE:
			var num uint64
			num, isNull, n = LengthEncodedInt(p[pos:])
			pos += n
			if isNull {
				data[i] = nil
				continue
			}
			data[i], err = FormatBinaryDate(int(num), p[pos:])
			pos += int(num)
			if err != nil {
				return nil, errors.Trace(err)
			}

		case MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME:
			var num uint64
			num, isNull, n = LengthEncodedInt(p[pos:])
			pos += n
			if isNull {
				data[i] = nil
				continue
			}
			data[i], err = FormatBinaryDateTime(int(num), p[pos:])
			pos += int(num)
			if err != nil {
				return nil, errors.Trace(err)
			}

		case MYSQL_TYPE_TIME:
			var num uint64
			num, isNull, n = LengthEncodedInt(p[pos:])
			pos += n
			if isNull {
				data[i] = nil
				continue
			}
			data[i], err = FormatBinaryTime(int(num), p[pos:])
			pos += int(num)
			if err != nil {
				return nil, errors.Trace(err)
			}

		default:
			return nil, errors.Errorf("Stmt Unknown FieldType %d %s", f[i].Type, f[i].Name)
		}
	}

	return data, nil
}
// Resultset holds the fields and rows of a query result and is the receiver
// for the typed accessor methods (GetValue, GetUint, GetString, ...).
type Resultset struct {
// Fields lists the result's fields; its length is what ColumnNumber
// reports and what GetValue checks column indexes against.
Fields []*Field
// FieldNames maps a field name to its column index; NameIndex looks
// names up here.
FieldNames map[string]int
// Values holds the row values, indexed as Values[row][column].
Values [][]interface{}
// RowDatas is not read by any method visible in this file; presumably
// the raw rows that Values was decoded from — TODO confirm.
RowDatas []RowData
}
// RowNumber reports how many rows are stored in r.Values.
func (r *Resultset) RowNumber() int {
	rows := r.Values
	return len(rows)
}
// ColumnNumber reports how many fields are stored in r.Fields.
func (r *Resultset) ColumnNumber() int {
	columns := r.Fields
	return len(columns)
}
// GetValue returns the value stored at Values[row][column].
//
// row is validated against len(r.Values) and column against len(r.Fields);
// an index outside either range yields an error naming the offending index.
func (r *Resultset) GetValue(row, column int) (interface{}, error) {
	switch {
	case row < 0 || row >= len(r.Values):
		return nil, errors.Errorf("invalid row index %d", row)
	case column < 0 || column >= len(r.Fields):
		return nil, errors.Errorf("invalid column index %d", column)
	}
	return r.Values[row][column], nil
}
// NameIndex returns the column index registered for name in r.FieldNames, or
// an error when the name is not present.
func (r *Resultset) NameIndex(name string) (int, error) {
	column, found := r.FieldNames[name]
	if !found {
		return 0, errors.Errorf("invalid field name %s", name)
	}
	return column, nil
}
// GetValueByName resolves name to a column index with NameIndex and returns
// GetValue for that column. A failed name lookup is returned traced.
func (r *Resultset) GetValueByName(row int, name string) (interface{}, error) {
	column, err := r.NameIndex(name)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return r.GetValue(row, column)
}
// IsNull reports whether the value at (row, column) is nil. Errors from
// GetValue are returned as-is, together with false.
func (r *Resultset) IsNull(row, column int) (bool, error) {
	value, err := r.GetValue(row, column)
	if err == nil {
		return value == nil, nil
	}
	return false, err
}
// IsNullByName resolves name to a column index with NameIndex and returns
// IsNull for that column. A failed name lookup yields false and the error.
func (r *Resultset) IsNullByName(row int, name string) (bool, error) {
	column, err := r.NameIndex(name)
	if err != nil {
		return false, err
	}
	return r.IsNull(row, column)
}
// GetUint returns the value at (row, column) as a uint64.
//
// uint64 values are returned unchanged; int64 and float64 values are
// converted with a plain uint64 conversion; string and []byte values are
// parsed as base-10 unsigned integers; nil yields 0. Any other dynamic type
// produces an error naming that type. Errors from GetValue are returned
// as-is.
func (r *Resultset) GetUint(row, column int) (uint64, error) {
	value, err := r.GetValue(row, column)
	if err != nil {
		return 0, err
	}
	if value == nil {
		return 0, nil
	}

	switch x := value.(type) {
	case []byte:
		return strconv.ParseUint(string(x), 10, 64)
	case string:
		return strconv.ParseUint(x, 10, 64)
	case float64:
		return uint64(x), nil
	case int64:
		return uint64(x), nil
	case uint64:
		return x, nil
	}
	return 0, errors.Errorf("data type is %T", value)
}
// GetUintByName resolves name to a column index with NameIndex and returns
// GetUint for that column. A failed name lookup yields 0 and the error.
func (r *Resultset) GetUintByName(row int, name string) (uint64, error) {
	column, err := r.NameIndex(name)
	if err != nil {
		return 0, err
	}
	return r.GetUint(row, column)
}
// GetInt returns the value at (row, column) as an int64.
//
// Negative inputs are handled here: an int64 is returned unchanged, a
// negative float64 is converted directly to int64, and a string or []byte
// that starts with '-' is parsed as a signed base-10 integer. Routing those
// through GetUint would fail in the unsigned parser or pass through an
// out-of-range unsigned conversion.
//
// Every other input keeps going through GetUint and is then converted to
// int64, so unsigned values above the int64 range still wrap around as they
// always have. Errors from GetValue, GetUint and the parser are returned
// as-is.
func (r *Resultset) GetInt(row, column int) (int64, error) {
	d, err := r.GetValue(row, column)
	if err != nil {
		return 0, err
	}

	switch v := d.(type) {
	case int64:
		return v, nil
	case float64:
		if v < 0 {
			return int64(v), nil
		}
	case string:
		if len(v) > 0 && v[0] == '-' {
			return strconv.ParseInt(v, 10, 64)
		}
	case []byte:
		if len(v) > 0 && v[0] == '-' {
			return strconv.ParseInt(string(v), 10, 64)
		}
	}

	// Non-negative (or non-numeric) inputs: same path as before.
	u, err := r.GetUint(row, column)
	if err != nil {
		return 0, err
	}
	return int64(u), nil
}
// GetIntByName resolves name to a column index with NameIndex and returns
// GetInt for that column, so the by-name and by-index integer accessors
// always agree. A failed name lookup yields 0 and the error.
func (r *Resultset) GetIntByName(row int, name string) (int64, error) {
	column, err := r.NameIndex(name)
	if err != nil {
		return 0, err
	}
	return r.GetInt(row, column)
}
// GetFloat returns the value at (row, column) as a float64.
//
// float64 values are returned unchanged; uint64 and int64 values are
// converted with a plain float64 conversion; string and []byte values are
// parsed with strconv.ParseFloat; nil yields 0. Any other dynamic type
// produces an error naming that type. Errors from GetValue are returned
// as-is.
func (r *Resultset) GetFloat(row, column int) (float64, error) {
	value, err := r.GetValue(row, column)
	if err != nil {
		return 0, err
	}
	if value == nil {
		return 0, nil
	}

	switch x := value.(type) {
	case []byte:
		return strconv.ParseFloat(string(x), 64)
	case string:
		return strconv.ParseFloat(x, 64)
	case int64:
		return float64(x), nil
	case uint64:
		return float64(x), nil
	case float64:
		return x, nil
	}
	return 0, errors.Errorf("data type is %T", value)
}
// GetFloatByName resolves name to a column index with NameIndex and returns
// GetFloat for that column. A failed name lookup yields 0 and the error.
func (r *Resultset) GetFloatByName(row int, name string) (float64, error) {
	column, err := r.NameIndex(name)
	if err != nil {
		return 0, err
	}
	return r.GetFloat(row, column)
}
// GetString returns the value at (row, column) as a string.
//
// string values are returned unchanged; []byte values are converted with
// hack.String (presumably without copying — verify against that package);
// int64 and uint64 values are formatted in base 10; float64 values are
// formatted with the 'f' format and the shortest precision that round-trips;
// nil yields "". Any other dynamic type produces an error naming that type.
// Errors from GetValue are returned as-is.
func (r *Resultset) GetString(row, column int) (string, error) {
	value, err := r.GetValue(row, column)
	if err != nil {
		return "", err
	}
	if value == nil {
		return "", nil
	}

	switch x := value.(type) {
	case []byte:
		return hack.String(x), nil
	case string:
		return x, nil
	case uint64:
		return strconv.FormatUint(x, 10), nil
	case int64:
		return strconv.FormatInt(x, 10), nil
	case float64:
		return strconv.FormatFloat(x, 'f', -1, 64), nil
	}
	return "", errors.Errorf("data type is %T", value)
}
// GetStringByName resolves name to a column index with NameIndex and returns
// GetString for that column. A failed name lookup yields "" and the error.
func (r *Resultset) GetStringByName(row int, name string) (string, error) {
	column, err := r.NameIndex(name)
	if err != nil {
		return "", err
	}
	return r.GetString(row, column)
}