231 lines
5.2 KiB
Go
Raw Normal View History

2017-02-12 13:13:54 +02:00
package server
import (
"database/sql"
"flag"
"fmt"
"net"
"strings"
"testing"
"time"
_ "github.com/go-sql-driver/mysql"
"github.com/juju/errors"
. "github.com/pingcap/check"
mysql "github.com/siddontang/go-mysql/mysql"
)
var testAddr = flag.String("addr", "127.0.0.1:4000", "MySQL proxy server address")
var testUser = flag.String("user", "root", "MySQL user")
var testPassword = flag.String("pass", "", "MySQL password")
var testDB = flag.String("db", "test", "MySQL test database")
func Test(t *testing.T) {
TestingT(t)
}
type serverTestSuite struct {
db *sql.DB
l net.Listener
}
var _ = Suite(&serverTestSuite{})
type testHandler struct {
s *serverTestSuite
}
func (h *testHandler) UseDB(dbName string) error {
return nil
}
func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, error) {
ss := strings.Split(query, " ")
switch strings.ToLower(ss[0]) {
case "select":
var r *mysql.Resultset
var err error
//for handle go mysql driver select @@max_allowed_packet
if strings.Contains(strings.ToLower(query), "max_allowed_packet") {
r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{
[]interface{}{mysql.MaxPayloadLen},
}, binary)
} else {
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
[]interface{}{1, "hello world"},
}, binary)
}
if err != nil {
return nil, errors.Trace(err)
} else {
return &mysql.Result{0, 0, 0, r}, nil
}
case "insert":
return &mysql.Result{0, 1, 0, nil}, nil
case "delete":
return &mysql.Result{0, 0, 1, nil}, nil
case "update":
return &mysql.Result{0, 0, 1, nil}, nil
case "replace":
return &mysql.Result{0, 0, 1, nil}, nil
default:
return nil, fmt.Errorf("invalid query %s", query)
}
return nil, nil
}
func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) {
return h.handleQuery(query, false)
}
func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {
return nil, nil
}
func (h *testHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) {
ss := strings.Split(sql, " ")
switch strings.ToLower(ss[0]) {
case "select":
params = 1
columns = 2
case "insert":
params = 2
columns = 0
case "replace":
params = 2
columns = 0
case "update":
params = 1
columns = 0
case "delete":
params = 1
columns = 0
default:
err = fmt.Errorf("invalid prepare %s", sql)
}
return params, columns, nil, err
}
func (h *testHandler) HandleStmtClose(context interface{}) error {
return nil
}
func (h *testHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) {
return h.handleQuery(query, true)
}
func (s *serverTestSuite) SetUpSuite(c *C) {
var err error
s.l, err = net.Listen("tcp", *testAddr)
c.Assert(err, IsNil)
go s.onAccept(c)
time.Sleep(500 * time.Millisecond)
s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", *testUser, *testPassword, *testAddr, *testDB))
c.Assert(err, IsNil)
s.db.SetMaxIdleConns(4)
}
func (s *serverTestSuite) TearDownSuite(c *C) {
if s.db != nil {
s.db.Close()
}
if s.l != nil {
s.l.Close()
}
}
func (s *serverTestSuite) onAccept(c *C) {
for {
conn, err := s.l.Accept()
if err != nil {
return
}
go s.onConn(conn, c)
}
}
func (s *serverTestSuite) onConn(conn net.Conn, c *C) {
co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s})
c.Assert(err, IsNil)
for {
err = co.HandleCommand()
if err != nil {
return
}
}
}
func (s *serverTestSuite) TestSelect(c *C) {
var a int64
var b string
err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=1").Scan(&a, &b)
c.Assert(err, IsNil)
c.Assert(a, Equals, int64(1))
c.Assert(b, Equals, "hello world")
}
func (s *serverTestSuite) TestExec(c *C) {
r, err := s.db.Exec("INSERT INTO tbl (a, b) values (1, \"hello world\")")
c.Assert(err, IsNil)
i, _ := r.LastInsertId()
c.Assert(i, Equals, int64(1))
r, err = s.db.Exec("REPLACE INTO tbl (a, b) values (1, \"hello world\")")
c.Assert(err, IsNil)
i, _ = r.RowsAffected()
c.Assert(i, Equals, int64(1))
r, err = s.db.Exec("UPDATE tbl SET b = \"abc\" where a = 1")
c.Assert(err, IsNil)
i, _ = r.RowsAffected()
c.Assert(i, Equals, int64(1))
r, err = s.db.Exec("DELETE FROM tbl where a = 1")
c.Assert(err, IsNil)
i, _ = r.RowsAffected()
c.Assert(i, Equals, int64(1))
}
func (s *serverTestSuite) TestStmtSelect(c *C) {
var a int64
var b string
err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=?", 1).Scan(&a, &b)
c.Assert(err, IsNil)
c.Assert(a, Equals, int64(1))
c.Assert(b, Equals, "hello world")
}
func (s *serverTestSuite) TestStmtExec(c *C) {
r, err := s.db.Exec("INSERT INTO tbl (a, b) values (?, ?)", 1, "hello world")
c.Assert(err, IsNil)
i, _ := r.LastInsertId()
c.Assert(i, Equals, int64(1))
r, err = s.db.Exec("REPLACE INTO tbl (a, b) values (?, ?)", 1, "hello world")
c.Assert(err, IsNil)
i, _ = r.RowsAffected()
c.Assert(i, Equals, int64(1))
r, err = s.db.Exec("UPDATE tbl SET b = \"abc\" where a = ?", 1)
c.Assert(err, IsNil)
i, _ = r.RowsAffected()
c.Assert(i, Equals, int64(1))
r, err = s.db.Exec("DELETE FROM tbl where a = ?", 1)
c.Assert(err, IsNil)
i, _ = r.RowsAffected()
c.Assert(i, Equals, int64(1))
}