231 lines
5.2 KiB
Go
231 lines
5.2 KiB
Go
|
package server
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"flag"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
_ "github.com/go-sql-driver/mysql"
|
||
|
"github.com/juju/errors"
|
||
|
. "github.com/pingcap/check"
|
||
|
mysql "github.com/siddontang/go-mysql/mysql"
|
||
|
)
|
||
|
|
||
|
// Command-line flags that configure the in-process test server and the client
// connection opened against it.
var (
	// testAddr is the TCP address the test listener binds to and the client dials.
	testAddr = flag.String("addr", "127.0.0.1:4000", "MySQL proxy server address")
	// testUser is the user name passed to NewConn and embedded in the client DSN.
	testUser = flag.String("user", "root", "MySQL user")
	// testPassword is the password passed to NewConn and embedded in the client DSN.
	testPassword = flag.String("pass", "", "MySQL password")
	// testDB is the database name embedded in the client DSN.
	testDB = flag.String("db", "test", "MySQL test database")
)
|
||
|
|
||
|
func Test(t *testing.T) {
|
||
|
TestingT(t)
|
||
|
}
|
||
|
|
||
|
// serverTestSuite carries the state shared by every test in this file: the
// client-side database handle and the listener serving its connections.
type serverTestSuite struct {
	// db is the database/sql handle opened in SetUpSuite and used by the
	// Test* methods to issue statements.
	db *sql.DB

	// l is the TCP listener opened in SetUpSuite; onAccept pulls incoming
	// connections from it until it is closed in TearDownSuite.
	l net.Listener
}
|
||
|
|
||
|
// Hand a serverTestSuite instance to Suite at package initialisation.
// NOTE(review): presumably this is the check package's registration hook that
// lets TestingT (called from Test) discover the suite's methods — confirm.
var _ = Suite(&serverTestSuite{})
|
||
|
|
||
|
type testHandler struct {
|
||
|
s *serverTestSuite
|
||
|
}
|
||
|
|
||
|
func (h *testHandler) UseDB(dbName string) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, error) {
|
||
|
ss := strings.Split(query, " ")
|
||
|
switch strings.ToLower(ss[0]) {
|
||
|
case "select":
|
||
|
var r *mysql.Resultset
|
||
|
var err error
|
||
|
//for handle go mysql driver select @@max_allowed_packet
|
||
|
if strings.Contains(strings.ToLower(query), "max_allowed_packet") {
|
||
|
r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{
|
||
|
[]interface{}{mysql.MaxPayloadLen},
|
||
|
}, binary)
|
||
|
} else {
|
||
|
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
|
||
|
[]interface{}{1, "hello world"},
|
||
|
}, binary)
|
||
|
}
|
||
|
|
||
|
if err != nil {
|
||
|
return nil, errors.Trace(err)
|
||
|
} else {
|
||
|
return &mysql.Result{0, 0, 0, r}, nil
|
||
|
}
|
||
|
case "insert":
|
||
|
return &mysql.Result{0, 1, 0, nil}, nil
|
||
|
case "delete":
|
||
|
return &mysql.Result{0, 0, 1, nil}, nil
|
||
|
case "update":
|
||
|
return &mysql.Result{0, 0, 1, nil}, nil
|
||
|
case "replace":
|
||
|
return &mysql.Result{0, 0, 1, nil}, nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("invalid query %s", query)
|
||
|
}
|
||
|
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) {
|
||
|
return h.handleQuery(query, false)
|
||
|
}
|
||
|
|
||
|
func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
func (h *testHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) {
|
||
|
ss := strings.Split(sql, " ")
|
||
|
switch strings.ToLower(ss[0]) {
|
||
|
case "select":
|
||
|
params = 1
|
||
|
columns = 2
|
||
|
case "insert":
|
||
|
params = 2
|
||
|
columns = 0
|
||
|
case "replace":
|
||
|
params = 2
|
||
|
columns = 0
|
||
|
case "update":
|
||
|
params = 1
|
||
|
columns = 0
|
||
|
case "delete":
|
||
|
params = 1
|
||
|
columns = 0
|
||
|
default:
|
||
|
err = fmt.Errorf("invalid prepare %s", sql)
|
||
|
}
|
||
|
return params, columns, nil, err
|
||
|
}
|
||
|
|
||
|
func (h *testHandler) HandleStmtClose(context interface{}) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (h *testHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) {
|
||
|
return h.handleQuery(query, true)
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) SetUpSuite(c *C) {
|
||
|
var err error
|
||
|
|
||
|
s.l, err = net.Listen("tcp", *testAddr)
|
||
|
c.Assert(err, IsNil)
|
||
|
|
||
|
go s.onAccept(c)
|
||
|
|
||
|
time.Sleep(500 * time.Millisecond)
|
||
|
|
||
|
s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", *testUser, *testPassword, *testAddr, *testDB))
|
||
|
c.Assert(err, IsNil)
|
||
|
|
||
|
s.db.SetMaxIdleConns(4)
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) TearDownSuite(c *C) {
|
||
|
if s.db != nil {
|
||
|
s.db.Close()
|
||
|
}
|
||
|
|
||
|
if s.l != nil {
|
||
|
s.l.Close()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) onAccept(c *C) {
|
||
|
for {
|
||
|
conn, err := s.l.Accept()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
go s.onConn(conn, c)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) onConn(conn net.Conn, c *C) {
|
||
|
co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s})
|
||
|
c.Assert(err, IsNil)
|
||
|
|
||
|
for {
|
||
|
err = co.HandleCommand()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) TestSelect(c *C) {
|
||
|
var a int64
|
||
|
var b string
|
||
|
|
||
|
err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=1").Scan(&a, &b)
|
||
|
c.Assert(err, IsNil)
|
||
|
c.Assert(a, Equals, int64(1))
|
||
|
c.Assert(b, Equals, "hello world")
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) TestExec(c *C) {
|
||
|
r, err := s.db.Exec("INSERT INTO tbl (a, b) values (1, \"hello world\")")
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ := r.LastInsertId()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
|
||
|
r, err = s.db.Exec("REPLACE INTO tbl (a, b) values (1, \"hello world\")")
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ = r.RowsAffected()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
|
||
|
r, err = s.db.Exec("UPDATE tbl SET b = \"abc\" where a = 1")
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ = r.RowsAffected()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
|
||
|
r, err = s.db.Exec("DELETE FROM tbl where a = 1")
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ = r.RowsAffected()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) TestStmtSelect(c *C) {
|
||
|
var a int64
|
||
|
var b string
|
||
|
|
||
|
err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=?", 1).Scan(&a, &b)
|
||
|
c.Assert(err, IsNil)
|
||
|
c.Assert(a, Equals, int64(1))
|
||
|
c.Assert(b, Equals, "hello world")
|
||
|
}
|
||
|
|
||
|
func (s *serverTestSuite) TestStmtExec(c *C) {
|
||
|
r, err := s.db.Exec("INSERT INTO tbl (a, b) values (?, ?)", 1, "hello world")
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ := r.LastInsertId()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
|
||
|
r, err = s.db.Exec("REPLACE INTO tbl (a, b) values (?, ?)", 1, "hello world")
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ = r.RowsAffected()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
|
||
|
r, err = s.db.Exec("UPDATE tbl SET b = \"abc\" where a = ?", 1)
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ = r.RowsAffected()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
|
||
|
r, err = s.db.Exec("DELETE FROM tbl where a = ?", 1)
|
||
|
c.Assert(err, IsNil)
|
||
|
i, _ = r.RowsAffected()
|
||
|
c.Assert(i, Equals, int64(1))
|
||
|
}
|