gh-ost/vendor/github.com/go-mysql-org/go-mysql/client/conn.go

300 lines
6.0 KiB
Go
Raw Normal View History

2016-06-16 09:15:56 +00:00
package client
import (
2017-02-12 11:13:54 +00:00
"crypto/tls"
2016-06-16 09:15:56 +00:00
"fmt"
"net"
"strings"
"time"
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
2016-06-16 09:15:56 +00:00
)
type Conn struct {
*packet.Conn
2017-02-12 11:13:54 +00:00
user string
password string
db string
tlsConfig *tls.Config
proto string
2016-06-16 09:15:56 +00:00
capability uint32
status uint16
charset string
salt []byte
authPluginName string
2016-06-16 09:15:56 +00:00
connectionID uint32
}
// SelectPerRowCallback is the callback that ExecuteSelectStreaming invokes
// once for every row of the result set. How a non-nil returned error is
// handled is decided by readResultStreaming (not visible here) — presumably
// it stops the stream; confirm.
type SelectPerRowCallback func(row []FieldValue) error
2016-06-16 09:15:56 +00:00
// getNetProto picks the network type used to dial addr: any address that
// contains a "/" is treated as a unix domain socket path, everything else
// as a TCP host:port.
func getNetProto(addr string) string {
	if strings.Contains(addr, "/") {
		return "unix"
	}
	return "tcp"
}
// Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
2017-02-12 11:13:54 +00:00
// Accepts a series of configuration functions as a variadic argument.
func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
2016-06-16 09:15:56 +00:00
proto := getNetProto(addr)
c := new(Conn)
var err error
conn, err := net.DialTimeout(proto, addr, 10*time.Second)
if err != nil {
return nil, errors.Trace(err)
}
if c.tlsConfig != nil {
c.Conn = packet.NewTLSConn(conn)
} else {
c.Conn = packet.NewConn(conn)
}
2016-06-16 09:15:56 +00:00
c.user = user
c.password = password
c.db = dbName
c.proto = proto
2016-06-16 09:15:56 +00:00
//use default charset here, utf-8
c.charset = DEFAULT_CHARSET
2017-02-12 11:13:54 +00:00
// Apply configuration functions.
for i := range options {
options[i](c)
}
2016-06-16 09:15:56 +00:00
if err = c.handshake(); err != nil {
return nil, errors.Trace(err)
}
return c, nil
}
func (c *Conn) handshake() error {
var err error
if err = c.readInitialHandshake(); err != nil {
c.Close()
return errors.Trace(err)
}
if err := c.writeAuthHandshake(); err != nil {
c.Close()
return errors.Trace(err)
}
if err := c.handleAuthResult(); err != nil {
2016-06-16 09:15:56 +00:00
c.Close()
return errors.Trace(err)
}
return nil
}
// Close closes the embedded packet connection and reports its error.
func (c *Conn) Close() error {
	pc := c.Conn
	return pc.Close()
}
// Ping sends COM_PING and waits for the server's OK reply. A failure of
// either the write or the read is returned (traced).
func (c *Conn) Ping() error {
	err := c.writeCommand(COM_PING)
	if err == nil {
		_, err = c.readOK()
	}
	// errors.Trace passes a nil error through unchanged.
	return errors.Trace(err)
}
// UseSSL enables TLS with a default tls.Config whose only customized field
// is InsecureSkipVerify. Meant to be passed to Connect as an option, e.g.
// func(c *Conn) { c.UseSSL(true) }.
func (c *Conn) UseSSL(insecureSkipVerify bool) {
	c.SetTLSConfig(&tls.Config{InsecureSkipVerify: insecureSkipVerify})
}
// SetTLSConfig installs a caller-supplied TLS configuration, replacing any
// previous one (passing nil clears it). Meant to be passed to Connect as an
// option.
func (c *Conn) SetTLSConfig(config *tls.Config) {
	c.tlsConfig = config
}
2016-06-16 09:15:56 +00:00
// UseDB switches the connection's default database with COM_INIT_DB. It does
// nothing when dbName is already the current database. The cached name is
// updated only after the server answers with OK.
func (c *Conn) UseDB(dbName string) error {
	if dbName == c.db {
		return nil
	}

	err := c.writeCommandStr(COM_INIT_DB, dbName)
	if err == nil {
		_, err = c.readOK()
	}
	if err != nil {
		return errors.Trace(err)
	}

	c.db = dbName
	return nil
}
// GetDB returns the database name cached on this connection, which is set
// by Connect and updated by UseDB.
func (c *Conn) GetDB() string {
	current := c.db
	return current
}
// Execute runs command on the server.
//
// With no args, command is sent as a plain text query through exec. With
// args, command is prepared, executed with args, and the prepared statement
// is closed before returning. The statement's Close error is not reported;
// only Execute's result and error are returned.
func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
	if len(args) == 0 {
		return c.exec(command)
	}

	stmt, err := c.Prepare(command)
	if err != nil {
		return nil, errors.Trace(err)
	}

	res, execErr := stmt.Execute(args...)
	stmt.Close()
	return res, execErr
}
// ExecuteSelectStreaming sends command as a text query and calls
// perRowCallback for every row in the result set WITHOUT saving any row data
// to Result.{Values/RawPkg/RowDatas} fields.
//
// It is intended only for SELECT queries with a large result set, where
// buffering every row would use too much memory.
//
// Example:
//
//	var result mysql.Result
//	conn.ExecuteSelectStreaming(`SELECT ... LIMIT 100500`, &result, func(row []mysql.FieldValue) error {
//		// Use the row as you want.
//		// You must not save FieldValue.AsString() value after this callback is done. Copy it if you need.
//		return nil
//	})
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback) error {
	err := c.writeCommandStr(COM_QUERY, command)
	if err != nil {
		return errors.Trace(err)
	}
	// The leading false matches exec's readResult(false) call — presumably
	// "text (not binary) protocol"; confirm against readResultStreaming.
	return c.readResultStreaming(false, result, perRowCallback)
}
2016-06-16 09:15:56 +00:00
// Begin starts a transaction by running the text query "BEGIN".
func (c *Conn) Begin() error {
	if _, err := c.exec("BEGIN"); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// Commit ends the current transaction by running the text query "COMMIT".
func (c *Conn) Commit() error {
	if _, err := c.exec("COMMIT"); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// Rollback aborts the current transaction by running the text query
// "ROLLBACK".
func (c *Conn) Rollback() error {
	if _, err := c.exec("ROLLBACK"); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// SetCharset changes the connection character set with "SET NAMES". It does
// nothing when charset is already current. The cached charset is updated
// only after the query succeeds.
//
// NOTE(review): charset is interpolated into the SQL text unescaped, so it
// must come from trusted code, never from user input.
func (c *Conn) SetCharset(charset string) error {
	if charset == c.charset {
		return nil
	}

	query := fmt.Sprintf("SET NAMES %s", charset)
	if _, err := c.exec(query); err != nil {
		return errors.Trace(err)
	}

	c.charset = charset
	return nil
}
func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
}
data, err := c.ReadPacket()
if err != nil {
return nil, errors.Trace(err)
}
fs := make([]*Field, 0, 4)
var f *Field
if data[0] == ERR_HEADER {
return nil, c.handleErrorPacket(data)
} else {
for {
if data, err = c.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
// EOF Packet
if c.isEOFPacket(data) {
return fs, nil
}
if f, err = FieldData(data).Parse(); err != nil {
return nil, errors.Trace(err)
}
fs = append(fs, f)
}
}
return nil, fmt.Errorf("field list error")
}
// SetAutoCommit turns autocommit on with "SET AUTOCOMMIT = 1". It skips the
// query when the cached server status already reports autocommit. Whether
// c.status is refreshed after the query depends on exec/readResult (not
// visible here).
func (c *Conn) SetAutoCommit() error {
	if c.IsAutoCommit() {
		return nil
	}
	_, err := c.exec("SET AUTOCOMMIT = 1")
	// errors.Trace passes a nil error through unchanged.
	return errors.Trace(err)
}
// IsAutoCommit reports whether the SERVER_STATUS_AUTOCOMMIT bit is set in
// the cached server status.
func (c *Conn) IsAutoCommit() bool {
	return c.status&SERVER_STATUS_AUTOCOMMIT != 0
}
// IsInTransaction reports whether the SERVER_STATUS_IN_TRANS bit is set in
// the cached server status.
func (c *Conn) IsInTransaction() bool {
	return c.status&SERVER_STATUS_IN_TRANS != 0
}
// GetCharset returns the cached connection character set: DEFAULT_CHARSET
// after Connect, or the last value applied by SetCharset.
func (c *Conn) GetCharset() string {
	current := c.charset
	return current
}
// GetConnectionID returns the connection ID stored on this Conn. It is not
// assigned in this file; presumably it comes from the server handshake.
func (c *Conn) GetConnectionID() uint32 {
	id := c.connectionID
	return id
}
// HandleOKPacket is an exported wrapper around handleOKPacket that returns
// only the Result. Any error from handleOKPacket is discarded, so callers
// cannot tell a parse failure apart from success here. What Result is
// returned on failure (possibly nil) depends on handleOKPacket, which is not
// visible here.
func (c *Conn) HandleOKPacket(data []byte) *Result {
	res, _ := c.handleOKPacket(data) // error intentionally not surfaced by this API
	return res
}
// HandleErrorPacket is an exported wrapper that converts a raw ERR packet
// into an error via handleErrorPacket.
func (c *Conn) HandleErrorPacket(data []byte) error {
	err := c.handleErrorPacket(data)
	return err
}
// ReadOKPacket is an exported wrapper that reads the next packet and
// expects it to be OK, delegating to readOK.
func (c *Conn) ReadOKPacket() (*Result, error) {
	res, err := c.readOK()
	return res, err
}
// exec sends query as a COM_QUERY text command and reads its result with
// readResult(false). A write failure is returned traced; the read result is
// returned as-is.
func (c *Conn) exec(query string) (*Result, error) {
	err := c.writeCommandStr(COM_QUERY, query)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return c.readResult(false)
}