Updated go-mysql library, now with fix to decimal. May break datetime

This commit is contained in:
Shlomi Noach 2019-01-01 10:57:46 +02:00
parent ffd8fa0ea8
commit 255314927d
119 changed files with 3615 additions and 16219 deletions

View File

@ -40,12 +40,13 @@ func NewGoMySQLReader(migrationContext *base.MigrationContext) (binlogReader *Go
serverId := uint32(migrationContext.ReplicaServerId)
binlogSyncerConfig := replication.BinlogSyncerConfig{
ServerID: serverId,
Flavor: "mysql",
Host: binlogReader.connectionConfig.Key.Hostname,
Port: uint16(binlogReader.connectionConfig.Key.Port),
User: binlogReader.connectionConfig.User,
Password: binlogReader.connectionConfig.Password,
ServerID: serverId,
Flavor: "mysql",
Host: binlogReader.connectionConfig.Key.Hostname,
Port: uint16(binlogReader.connectionConfig.Key.Port),
User: binlogReader.connectionConfig.User,
Password: binlogReader.connectionConfig.Password,
UseDecimal: true,
}
binlogReader.binlogSyncer = replication.NewBinlogSyncer(binlogSyncerConfig)

View File

@ -1,32 +1,34 @@
language: go
go:
- 1.6
- 1.7
- "1.9"
- "1.10"
dist: trusty
sudo: required
addons:
apt:
sources:
- mysql-5.7-trusty
packages:
- mysql-server-5.6
- mysql-client-core-5.6
- mysql-client-5.6
- mysql-server
- mysql-client
before_install:
- sudo mysql -e "use mysql; update user set authentication_string=PASSWORD('') where User='root'; update user set plugin='mysql_native_password';FLUSH PRIVILEGES;"
- sudo mysql_upgrade
before_script:
# stop mysql and use row-based format binlog
- "sudo /etc/init.d/mysql stop || true"
- "sudo service mysql stop || true"
- "echo '[mysqld]' | sudo tee /etc/mysql/conf.d/replication.cnf"
- "echo 'server-id=1' | sudo tee -a /etc/mysql/conf.d/replication.cnf"
- "echo 'log-bin=mysql' | sudo tee -a /etc/mysql/conf.d/replication.cnf"
- "echo 'log-bin=mysql' | sudo tee -a /etc/mysql/conf.d/replication.cnf"
- "echo 'binlog-format = row' | sudo tee -a /etc/mysql/conf.d/replication.cnf"
# Start mysql (avoid errors to have logs)
- "sudo /etc/init.d/mysql start || true"
- "sudo service mysql start || true"
- "sudo tail -1000 /var/log/syslog"
- mysql -e "CREATE DATABASE IF NOT EXISTS test;" -uroot
script:
- make test
- make test

View File

@ -1,33 +1,14 @@
all: build
build:
rm -rf vendor && ln -s _vendor/vendor vendor
go build -o bin/go-mysqlbinlog cmd/go-mysqlbinlog/main.go
go build -o bin/go-mysqldump cmd/go-mysqldump/main.go
go build -o bin/go-canal cmd/go-canal/main.go
go build -o bin/go-binlogparser cmd/go-binlogparser/main.go
rm -rf vendor
test:
rm -rf vendor && ln -s _vendor/vendor vendor
go test --race -timeout 2m ./...
rm -rf vendor
clean:
go clean -i ./...
@rm -rf ./bin
update_vendor:
which glide >/dev/null || curl https://glide.sh/get | sh
which glide-vc || go get -v -u github.com/sgotti/glide-vc
rm -r vendor && mv _vendor/vendor vendor || true
rm -rf _vendor
ifdef PKG
glide get --strip-vendor --skip-test ${PKG}
else
glide update --strip-vendor --skip-test
endif
@echo "removing test files"
glide vc --only-code --no-tests
mkdir -p _vendor
mv vendor _vendor/vendor
@rm -rf ./bin

View File

@ -15,7 +15,7 @@ import (
"github.com/siddontang/go-mysql/replication"
"os"
)
// Create a binlog syncer with a unique server id, the server id must be different from other MySQL's.
// Create a binlog syncer with a unique server id, the server id must be different from other MySQL's.
// flavor is mysql or mariadb
cfg := replication.BinlogSyncerConfig {
ServerID: 100,
@ -27,7 +27,7 @@ cfg := replication.BinlogSyncerConfig {
}
syncer := replication.NewBinlogSyncer(cfg)
// Start sync with sepcified binlog file and position
// Start sync with specified binlog file and position
streamer, _ := syncer.StartSync(mysql.Position{binlogFile, binlogPos})
// or you can start a gtid replication like
@ -44,7 +44,7 @@ for {
// or we can use a timeout context
for {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
e, _ := s.GetEvent(ctx)
ev, err := s.GetEvent(ctx)
cancel()
if err == context.DeadlineExceeded {
@ -85,13 +85,13 @@ Schema: test
Query: DROP TABLE IF EXISTS `test_replication` /* generated by server */
```
## Canal
## Canal
Canal is a package that can sync your MySQL into everywhere, like Redis, Elasticsearch.
Canal is a package that can sync your MySQL into everywhere, like Redis, Elasticsearch.
First, canal will dump your MySQL data then sync changed data using binlog incrementally.
First, canal will dump your MySQL data then sync changed data using binlog incrementally.
You must use ROW format for binlog, full binlog row image is preferred, because we may meet some errors when primary key changed in update for minimal or noblob row image.
You must use ROW format for binlog, full binlog row image is preferred, because we may meet some errors when primary key changed in update for minimal or noblob row image.
A simple example:
@ -105,30 +105,31 @@ cfg.Dump.Tables = []string{"canal_test"}
c, err := NewCanal(cfg)
type myRowsEventHandler struct {
type MyEventHandler struct {
DummyEventHandler
}
func (h *myRowsEventHandler) Do(e *RowsEvent) error {
func (h *MyEventHandler) OnRow(e *RowsEvent) error {
log.Infof("%s %v\n", e.Action, e.Rows)
return nil
}
func (h *myRowsEventHandler) String() string {
return "myRowsEventHandler"
func (h *MyEventHandler) String() string {
return "MyEventHandler"
}
// Register a handler to handle RowsEvent
c.RegRowsEventHandler(&MyRowsEventHandler{})
c.SetEventHandler(&MyEventHandler{})
// Start canal
c.Start()
```
You can see [go-mysql-elasticsearch](https://github.com/siddontang/go-mysql-elasticsearch) for how to sync MySQL data into Elasticsearch.
You can see [go-mysql-elasticsearch](https://github.com/siddontang/go-mysql-elasticsearch) for how to sync MySQL data into Elasticsearch.
## Client
Client package supports a simple MySQL connection driver which you can use it to communicate with MySQL server.
Client package supports a simple MySQL connection driver which you can use it to communicate with MySQL server.
### Example
@ -137,9 +138,16 @@ import (
"github.com/siddontang/go-mysql/client"
)
// Connect MySQL at 127.0.0.1:3306, with user root, an empty passowrd and database test
// Connect MySQL at 127.0.0.1:3306, with user root, an empty password and database test
conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test")
// Or to use SSL/TLS connection if MySQL server supports TLS
//conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test", func(c *Conn) {c.UseSSL(true)})
// or to set your own client-side certificates for identity verification for security
//tlsConfig := NewClientTLSConfig(caPem, certPem, keyPem, false, "your-server-name")
//conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test", func(c *Conn) {c.SetTLSConfig(tlsConfig)})
conn.Ping()
// Insert
@ -153,13 +161,20 @@ r, _ := conn.Execute(`select id, name from table where id = 1`)
// Handle resultset
v, _ := r.GetInt(0, 0)
v, _ = r.GetIntByName(0, "id")
v, _ = r.GetIntByName(0, "id")
```
Tested MySQL versions for the client include:
- 5.5.x
- 5.6.x
- 5.7.x
- 8.0.x
## Server
Server package supplies a framework to implement a simple MySQL server which can handle the packets from the MySQL client.
You can use it to build your own MySQL proxy.
Server package supplies a framework to implement a simple MySQL server which can handle the packets from the MySQL client.
You can use it to build your own MySQL proxy. The server connection is compatible with MySQL 5.5, 5.6, 5.7, and 8.0 versions,
so that most MySQL clients should be able to connect to the Server without modifications.
### Example
@ -173,42 +188,53 @@ l, _ := net.Listen("tcp", "127.0.0.1:4000")
c, _ := l.Accept()
// Create a connection with user root and an empty passowrd
// We only an empty handler to handle command too
// Create a connection with user root and an empty password.
// You can use your own handler to handle command here.
conn, _ := server.NewConn(c, "root", "", server.EmptyHandler{})
for {
conn.HandleCommand()
}
```
```
Another shell
```
mysql -h127.0.0.1 -P4000 -uroot -p
//Becuase empty handler does nothing, so here the MySQL client can only connect the proxy server. :-)
mysql -h127.0.0.1 -P4000 -uroot -p
//Becuase empty handler does nothing, so here the MySQL client can only connect the proxy server. :-)
```
> ```NewConn()``` will use default server configurations:
> 1. automatically generate default server certificates and enable TLS/SSL support.
> 2. support three mainstream authentication methods **'mysql_native_password'**, **'caching_sha2_password'**, and **'sha256_password'**
> and use **'mysql_native_password'** as default.
> 3. use an in-memory user credential provider to store user and password.
>
> To customize server configurations, use ```NewServer()``` and create connection via ```NewCustomizedConn()```.
## Failover
Failover supports to promote a new master and let other slaves replicate from it automatically when the old master was down.
Failover supports MySQL >= 5.6.9 with GTID mode, if you use lower version, e.g, MySQL 5.0 - 5.5, please use [MHA](http://code.google.com/p/mysql-master-ha/) or [orchestrator](https://github.com/outbrain/orchestrator).
At the same time, Failover supports MariaDB >= 10.0.9 with GTID mode too.
At the same time, Failover supports MariaDB >= 10.0.9 with GTID mode too.
Why only GTID? Supporting failover with no GTID mode is very hard, because slave can not find the proper binlog filename and position with the new master.
Although there are many companies use MySQL 5.0 - 5.5, I think upgrade MySQL to 5.6 or higher is easy.
Why only GTID? Supporting failover with no GTID mode is very hard, because slave can not find the proper binlog filename and position with the new master.
Although there are many companies use MySQL 5.0 - 5.5, I think upgrade MySQL to 5.6 or higher is easy.
## Driver
Driver is the package that you can use go-mysql with go database/sql like other drivers. A simple example:
```
package main
import (
"database/sql"
- "github.com/siddontang/go-mysql/driver"
_ "github.com/siddontang/go-mysql/driver"
)
func main() {
@ -221,9 +247,17 @@ func main() {
We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-)
## Donate
If you like the project and want to buy me a cola, you can through:
|PayPal|微信|
|------|---|
|[![](https://www.paypalobjects.com/webstatic/paypalme/images/pp_logo_small.png)](https://paypal.me/siddontang)|[![](https://github.com/siddontang/blog/blob/master/donate/weixin.png)|
## Feedback
go-mysql is still in development, your feedback is very welcome.
go-mysql is still in development, your feedback is very welcome.
Gmail: siddontang@gmail.com

View File

@ -1,14 +0,0 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@ -1,492 +0,0 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"math"
"reflect"
"strings"
"time"
)
var e = fmt.Errorf
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive struct {
undecoded interface{}
context Key
}
// DEPRECATED!
//
// Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, rvalue(v))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
copy(context, md.context)
rv.Set(reflect.ValueOf(Primitive{
undecoded: data,
context: context,
}))
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
return md.unifyMap(data, rv)
case reflect.Array:
return md.unifyArray(data, rv)
case reflect.Slice:
return md.unifySlice(data, rv)
case reflect.String:
return md.unifyString(data, rv)
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("Unsupported type '%s'.", rv.Kind())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("Unsupported type '%s'.", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return mismatch(rv, "map", mapping)
}
for key, datum := range tmap {
var f *field
fields := cachedTypeFields(rv.Type())
for i := range fields {
ff := &fields[i]
if ff.name == key {
f = ff
break
}
if f == nil && strings.EqualFold(ff.name, key) {
f = ff
}
}
if f != nil {
subv := rv
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return e("Type mismatch for '%s.%s': %s",
rv.Type().String(), f.name, err)
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("Field '%s.%s' is unexported, and therefore cannot "+
"be loaded with reflection.", rv.Type().String(), f.name)
}
}
}
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if rv.IsNil() {
rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen))
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("Value '%d' is out of range for int8.", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("Value '%d' is out of range for int16.", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("Value '%d' is out of range for int32.", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("Value '%d' is out of range for uint8.", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("Value '%d' is out of range for uint16.", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("Value '%d' is out of range for uint32.", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
}
return nil
}
return badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
}
s = string(text)
case fmt.Stringer:
s = sdata.String()
case string:
s = sdata
case bool:
s = fmt.Sprintf("%v", sdata)
case int64:
s = fmt.Sprintf("%d", sdata)
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanAddr() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
}
}
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("Expected %s but found '%T'.", expected, data)
}
func mismatch(user reflect.Value, expected string, data interface{}) error {
return e("Type mismatch for %s. Expected %s but found '%T'.",
user.Type().String(), expected, data)
}

View File

@ -1,122 +0,0 @@
package toml
import "strings"
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
} else {
return k[i]
}
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
// and so no keys will be considered decoded.
//
// In this sense, the Undecoded keys correspond to keys in the TOML document
// that do not have a concrete type in your representation.
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
undecoded = append(undecoded, key)
}
}
return undecoded
}

View File

@ -1,27 +0,0 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/mojombo/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml

View File

@ -1,551 +0,0 @@
package toml
import (
"bufio"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"can't encode array with mixed element types")
errArrayNilElement = errors.New(
"can't encode array with nil element")
errNonString = errors.New(
"can't encode a map with non-string key type")
errAnonNonStruct = errors.New(
"can't encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"TOML array element can't contain a table")
errNoKey = errors.New(
"top-level values must be a Go map or struct")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
if terr, ok := r.(tomlEncodeError); ok {
err = terr.error
return
}
panic(r)
}
}()
enc.encode(key, rv)
return nil
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we used that.
// Basically, this prevents the encoder for handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
return
}
k := rv.Kind()
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
}
case reflect.Interface:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Map:
if rv.IsNil() {
return
}
enc.eTable(key, rv)
case reflect.Ptr:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("Unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("Unexpected primitive type: %s", rv.Kind()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
}
return fstr
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
}
}
enc.wf("]")
}
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra new line between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
case reflect.Struct:
enc.eStruct(key, rv)
default:
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
}
// Sort keys so that we have deterministic output. And write keys directly
// underneath this key first, before writing sub-structs or sub-maps.
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
continue
}
enc.encode(key.add(mapKey), mrv)
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexporded fields
if f.PkgPath != "" {
continue
}
frv := rv.Field(i)
if f.Anonymous {
frv := eindirect(frv)
t := frv.Type()
if t.Kind() != reflect.Struct {
encPanic(errAnonNonStruct)
}
addFields(t, frv, f.Index)
} else if typeIsHash(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
continue
}
keyName := sft.Tag.Get("toml")
if keyName == "-" {
continue
}
if keyName == "" {
keyName = sft.Name
}
keyName, opts := getOptions(keyName)
if _, ok := opts["omitempty"]; ok && isEmpty(sf) {
continue
} else if _, ok := opts["omitzero"]; ok && isZero(sf) {
continue
}
enc.encode(key.add(keyName), sf)
}
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
}
// tomlTypeName returns the TOML type name of the Go value's type. It is
// used to determine whether the types of array elements are mixed (which is
// forbidden). If the Go value is nil, then it is illegal for it to be an array
// element, and valueIsNil is returned as true.
// Returns the TOML type of a Go value. The type may be `nil`, which means
// no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInteger
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
} else {
return tomlArray
}
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
}
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
func getOptions(keyName string) (string, map[string]struct{}) {
opts := make(map[string]struct{})
ss := strings.Split(keyName, ",")
name := ss[0]
if len(ss) > 1 {
for _, opt := range ss {
opts[opt] = struct{}{}
}
}
return name, opts
}
func isZero(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if rv.Int() == 0 {
return true
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if rv.Uint() == 0 {
return true
}
case reflect.Float32, reflect.Float64:
if rv.Float() == 0.0 {
return true
}
}
return false
}
func isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.String:
if len(strings.TrimSpace(rv.String())) == 0 {
return true
}
case reflect.Array, reflect.Slice, reflect.Map:
if rv.Len() == 0 {
return true
}
}
return false
}
func (enc *Encoder) newline() {
if enc.hasWritten {
enc.wf("\n")
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
encPanic(err)
}
enc.hasWritten = true
}
func (enc *Encoder) indentStr(key Key) string {
return strings.Repeat(enc.Indent, len(key)-1)
}
func encPanic(err error) {
panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
return v
}
}
func isNil(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}

View File

@ -1,19 +0,0 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

View File

@ -1,18 +0,0 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}

View File

@ -1,874 +0,0 @@
package toml
import (
"fmt"
"strings"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemRawString
itemMultilineString
itemRawMultilineString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
)
const (
eof = 0
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
arrayValTerm = ','
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
width int
line int
state stateFn
items chan item
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
}
func lex(input string) *lexer {
lx := &lexer{
input: input + "\n",
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop.")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) emitTrim(typ itemType) {
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.pos >= len(lx.input) {
lx.width = 0
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.pos += lx.width
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only once per call of next.
func (lx *lexer) backup() {
lx.pos -= lx.width
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (new lines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("Unexpected EOF.")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a new line. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a new line for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.ignore()
return lexTop
}
return lx.errorf("Expected a top-level item to end with a new line, "+
"comment or EOF, but got %q instead.", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("Expected end of table array name delimiter %q, "+
"but got %q instead.", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("Unexpected end of table name. (Table names cannot " +
"be empty.)")
case r == tableSep:
return lx.errorf("Unexpected table separator. (Table names cannot " +
"be empty.)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
return lexValue // reuse string lexing
case isWhitespace(r):
return lexTableNameStart
default:
return lexBareTableName
}
}
// lexTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareTableName
case r == tableSep || r == tableEnd:
lx.backup()
lx.emitTrim(itemText)
return lexTableNameEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
case r == tableSep:
lx.ignore()
return lexTableNameStart
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("Expected '.' or ']' to end table name, but got %q "+
"instead.", r)
}
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("Unexpected key separator %q.", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.emit(itemKeyStart)
lx.push(lexKeyEnd)
return lexValue // reuse string lexing
default:
lx.ignore()
lx.emit(itemKeyStart)
return lexBareKey
}
}
// lexBareKey consumes the text of a bare key. Assumes that the first character
// (which is not whitespace) has not yet been consumed.
func lexBareKey(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.emitTrim(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emitTrim(itemText)
return lexKeyEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
// separator).
func lexKeyEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case r == keySep:
return lexSkip(lx, lexValue)
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("Expected key separator %q, but got %q instead.",
keySep, r)
}
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT new lines.
// In array syntax, the array states are responsible for ignoring new
// lines.
r := lx.next()
if isWhitespace(r) {
return lexSkip(lx, lexValue)
}
switch {
case r == arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case r == stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
return lexMultilineString
}
lx.backup()
}
lx.ignore() // ignore the '"'
return lexString
case r == rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore """
return lexMultilineRawString
}
lx.backup()
}
lx.ignore() // ignore the "'"
return lexRawString
case r == 't':
return lexTrue
case r == 'f':
return lexFalse
case r == '-':
return lexNumberStart
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
case r == '.': // special error case, be kind to users
return lx.errorf("Floats must start with a digit, not '.'.")
}
return lx.errorf("Expected value but found %q instead.", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and new lines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == arrayValTerm:
return lx.errorf("Unexpected array value terminator %q.",
arrayValTerm)
case r == arrayEnd:
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes the cruft between values of an array. Namely,
// it ignores whitespace and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == arrayValTerm:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf("Expected an array value terminator %q or an array "+
"terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r)
}
// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has
// just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == '\\':
lx.push(lexString)
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '\\':
return lexMultilineStringEscape
case r == stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineString
}
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
// It assumes that the beginning "'" has already been consumed and ignored.
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexRawString
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemRawMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineRawString
}
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
lx.next()
return lexMultilineString
} else {
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
}
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '\\':
return lx.pop()
case 'u':
return lexShortUnicodeEscape
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("Invalid escape character %q. Only the following "+
"escape characters are allowed: "+
"\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+
"\\uXXXX and \\UXXXXXXXX.", r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected four hexadecimal digits after '\\u', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected eight hexadecimal digits after '\\U', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either a (positive) integer, float or
// datetime. It assumes that NO negative sign has been consumed.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumberOrDate
}
// lexNumberOrDate consumes either a (positive) integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '-':
if lx.pos-lx.start != 5 {
return lx.errorf("All ISO8601 dates must be in full Zulu form.")
}
return lexDateAfterYear
case isDigit(r):
return lexNumberOrDate
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format.
// It assumes that "YYYY-" has already been consumed.
func lexDateAfterYear(lx *lexer) stateFn {
formats := []rune{
// digits are '0'.
// everything else is direct equality.
'0', '0', '-', '0', '0',
'T',
'0', '0', ':', '0', '0', ':', '0', '0',
'Z',
}
for _, f := range formats {
r := lx.next()
if f == '0' {
if !isDigit(r) {
return lx.errorf("Expected digit in ISO8601 datetime, "+
"but found %q instead.", r)
}
} else if f != r {
return lx.errorf("Expected %q in ISO8601 datetime, "+
"but found %q instead.", f, r)
}
}
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that
// a negative sign has already been read, but that *no* digits have been
// consumed. lexNumberStart will move to the appropriate integer or float
// states.
func lexNumberStart(lx *lexer) stateFn {
// we MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
switch {
case isDigit(r):
return lexNumber
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloatStart starts the consumption of digits of a float after a '.'.
// Namely, at least one digit is required.
func lexFloatStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
return lx.errorf("Floats must have a digit after the '.', but got "+
"%q instead.", r)
}
return lexFloat
}
// lexFloat consumes the digits of a float after a '.'.
// Assumes that one digit has been consumed after a '.' already.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexConst consumes the s[1:] in s. It assumes that s[0] has already been
// consumed.
func lexConst(lx *lexer, s string) stateFn {
for i := range s[1:] {
if r := lx.next(); r != rune(s[i+1]) {
return lx.errorf("Expected %q, but found %q instead.", s[:i+1],
s[:i]+string(r))
}
}
return nil
}
// lexTrue consumes the "rue" in "true". It assumes that 't' has already
// been consumed.
func lexTrue(lx *lexer) stateFn {
if fn := lexConst(lx, "true"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexFalse consumes the "alse" in "false". It assumes that 'f' has already
// been consumed.
func lexFalse(lx *lexer) stateFn {
if fn := lexConst(lx, "false"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first new line character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' ||
r == '-'
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString:
return "String"
case itemRawString:
return "String"
case itemMultilineString:
return "String"
case itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}

View File

@ -1,498 +0,0 @@
package toml
import (
"fmt"
"log"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panicf("%s", it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// Gets a string for a key (or part of a key in a table name).
func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
num, err := strconv.ParseInt(it.val, 10, 64)
if err != nil {
// See comment below for floats describing why we make a
// distinction between a bug and a user error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
num, err := strconv.ParseFloat(it.val, 64)
if err != nil {
// Distinguish float values. Normally, it'd be a bug if the lexer
// provides an invalid float, but it's possible that the float is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
//
// This is also true for integers.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.bug("Expected float value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
t, err := time.Parse("2006-01-02T15:04:05Z", it.val)
if err != nil {
p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
return s
}
return s[1:len(s)]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
}
}
return strings.Join(esc, "")
}
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
continue
}
r += 1
if r >= len(s) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
replaced = append(replaced, escaped)
r += 9
}
}
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
// BUG(burntsushi)
// I honestly don't understand how this works. I can't seem
// to find a way to make this fail. I figured this would fail on invalid
// UTF-8 characters like U+DCFF, but it doesn't.
if !utf8.ValidString(string(rune(hex))) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}

View File

@ -1,91 +0,0 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
if t1 == nil || t2 == nil {
return false
}
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemMultilineString:
return tomlString
case itemRawString:
return tomlString
case itemRawMultilineString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}

View File

@ -1,241 +0,0 @@
package toml
// Struct field handling is adapted from code in encoding/json:
//
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the Go distribution.
import (
"reflect"
"sort"
"sync"
)
// A field represents a single field found in a struct.
type field struct {
name string // the name of the field (`toml` tag included)
tag bool // whether field has a `toml` tag
index []int // represents the depth of an anonymous field
typ reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth,
// then breaking ties with "name came from toml tag", then
// breaking ties with index sequence.
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.PkgPath != "" { // unexported
continue
}
name := sf.Tag.Get("toml")
if name == "-" {
continue
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
// Follow pointer.
ft = ft.Elem()
}
// Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != ""
if name == "" {
name = sf.Name
}
fields = append(fields, field{name, tagged, index, ft})
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
f := field{name: ft.Name(), index: index, typ: ft}
next = append(next, f)
}
}
}
}
sort.Sort(byName(fields))
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with TOML tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
return fields
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// TOML tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order. The winner
// must therefore be one with the shortest index length. Drop all
// longer entries, which is easy: just truncate the slice.
length := len(fields[0].index)
tagged := -1 // Index of first tagged field.
for i, f := range fields {
if len(f.index) > length {
fields = fields[:i]
break
}
if f.tag {
if tagged >= 0 {
// Multiple tagged fields at the same level: conflict.
// Return no field.
return field{}, false
}
tagged = i
}
}
if tagged >= 0 {
return fields[tagged], true
}
// All remaining fields have the same length. If there's more than one,
// we have a conflict (two fields named "X" at the same level) and we
// return no field.
if len(fields) > 1 {
return field{}, false
}
return fields[0], true
}
var fieldCache struct {
sync.RWMutex
m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) []field {
fieldCache.RLock()
f := fieldCache.m[t]
fieldCache.RUnlock()
if f != nil {
return f
}
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = typeFields(t)
if f == nil {
f = []field{}
}
fieldCache.Lock()
if fieldCache.m == nil {
fieldCache.m = map[reflect.Type][]field{}
}
fieldCache.m[t] = f
fieldCache.Unlock()
return f
}

View File

@ -1,373 +0,0 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

View File

@ -1,19 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build appengine
package mysql
import (
"appengine/cloudsql"
)
func init() {
RegisterDial("cloudsql", cloudsql.Dial)
}

View File

@ -1,147 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"io"
"net"
"time"
)
const defaultBufSize = 4096
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
nc net.Conn
idx int
length int
timeout time.Duration
}
func newBuffer(nc net.Conn) buffer {
var b [defaultBufSize]byte
return buffer{
buf: b[:],
nc: nc,
}
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
n := b.length
// move existing data to the beginning
if n > 0 && b.idx > 0 {
copy(b.buf[0:n], b.buf[b.idx:])
}
// grow buffer if necessary
// TODO: let the buffer shrink again at some point
// Maybe keep the org buf slice and swap back?
if need > len(b.buf) {
// Round up to the next multiple of the default size
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
copy(newBuf, b.buf)
b.buf = newBuf
}
b.idx = 0
for {
if b.timeout > 0 {
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
nn, err := b.nc.Read(b.buf[n:])
n += nn
switch err {
case nil:
if n < need {
continue
}
b.length = n
return nil
case io.EOF:
if n >= need {
b.length = n
return nil
}
return io.ErrUnexpectedEOF
default:
return err
}
}
}
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
}
// returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
if b.length > 0 {
return nil
}
// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
}
return make([]byte, length)
}
// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length == 0 {
return b.buf[:length]
}
return nil
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
if b.length == 0 {
return b.buf
}
return nil
}

View File

@ -1,250 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const defaultCollation = "utf8_general_ci"
// A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query:
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS
var collations = map[string]byte{
"big5_chinese_ci": 1,
"latin2_czech_cs": 2,
"dec8_swedish_ci": 3,
"cp850_general_ci": 4,
"latin1_german1_ci": 5,
"hp8_english_ci": 6,
"koi8r_general_ci": 7,
"latin1_swedish_ci": 8,
"latin2_general_ci": 9,
"swe7_swedish_ci": 10,
"ascii_general_ci": 11,
"ujis_japanese_ci": 12,
"sjis_japanese_ci": 13,
"cp1251_bulgarian_ci": 14,
"latin1_danish_ci": 15,
"hebrew_general_ci": 16,
"tis620_thai_ci": 18,
"euckr_korean_ci": 19,
"latin7_estonian_cs": 20,
"latin2_hungarian_ci": 21,
"koi8u_general_ci": 22,
"cp1251_ukrainian_ci": 23,
"gb2312_chinese_ci": 24,
"greek_general_ci": 25,
"cp1250_general_ci": 26,
"latin2_croatian_ci": 27,
"gbk_chinese_ci": 28,
"cp1257_lithuanian_ci": 29,
"latin5_turkish_ci": 30,
"latin1_german2_ci": 31,
"armscii8_general_ci": 32,
"utf8_general_ci": 33,
"cp1250_czech_cs": 34,
"ucs2_general_ci": 35,
"cp866_general_ci": 36,
"keybcs2_general_ci": 37,
"macce_general_ci": 38,
"macroman_general_ci": 39,
"cp852_general_ci": 40,
"latin7_general_ci": 41,
"latin7_general_cs": 42,
"macce_bin": 43,
"cp1250_croatian_ci": 44,
"utf8mb4_general_ci": 45,
"utf8mb4_bin": 46,
"latin1_bin": 47,
"latin1_general_ci": 48,
"latin1_general_cs": 49,
"cp1251_bin": 50,
"cp1251_general_ci": 51,
"cp1251_general_cs": 52,
"macroman_bin": 53,
"utf16_general_ci": 54,
"utf16_bin": 55,
"utf16le_general_ci": 56,
"cp1256_general_ci": 57,
"cp1257_bin": 58,
"cp1257_general_ci": 59,
"utf32_general_ci": 60,
"utf32_bin": 61,
"utf16le_bin": 62,
"binary": 63,
"armscii8_bin": 64,
"ascii_bin": 65,
"cp1250_bin": 66,
"cp1256_bin": 67,
"cp866_bin": 68,
"dec8_bin": 69,
"greek_bin": 70,
"hebrew_bin": 71,
"hp8_bin": 72,
"keybcs2_bin": 73,
"koi8r_bin": 74,
"koi8u_bin": 75,
"latin2_bin": 77,
"latin5_bin": 78,
"latin7_bin": 79,
"cp850_bin": 80,
"cp852_bin": 81,
"swe7_bin": 82,
"utf8_bin": 83,
"big5_bin": 84,
"euckr_bin": 85,
"gb2312_bin": 86,
"gbk_bin": 87,
"sjis_bin": 88,
"tis620_bin": 89,
"ucs2_bin": 90,
"ujis_bin": 91,
"geostd8_general_ci": 92,
"geostd8_bin": 93,
"latin1_spanish_ci": 94,
"cp932_japanese_ci": 95,
"cp932_bin": 96,
"eucjpms_japanese_ci": 97,
"eucjpms_bin": 98,
"cp1250_polish_ci": 99,
"utf16_unicode_ci": 101,
"utf16_icelandic_ci": 102,
"utf16_latvian_ci": 103,
"utf16_romanian_ci": 104,
"utf16_slovenian_ci": 105,
"utf16_polish_ci": 106,
"utf16_estonian_ci": 107,
"utf16_spanish_ci": 108,
"utf16_swedish_ci": 109,
"utf16_turkish_ci": 110,
"utf16_czech_ci": 111,
"utf16_danish_ci": 112,
"utf16_lithuanian_ci": 113,
"utf16_slovak_ci": 114,
"utf16_spanish2_ci": 115,
"utf16_roman_ci": 116,
"utf16_persian_ci": 117,
"utf16_esperanto_ci": 118,
"utf16_hungarian_ci": 119,
"utf16_sinhala_ci": 120,
"utf16_german2_ci": 121,
"utf16_croatian_ci": 122,
"utf16_unicode_520_ci": 123,
"utf16_vietnamese_ci": 124,
"ucs2_unicode_ci": 128,
"ucs2_icelandic_ci": 129,
"ucs2_latvian_ci": 130,
"ucs2_romanian_ci": 131,
"ucs2_slovenian_ci": 132,
"ucs2_polish_ci": 133,
"ucs2_estonian_ci": 134,
"ucs2_spanish_ci": 135,
"ucs2_swedish_ci": 136,
"ucs2_turkish_ci": 137,
"ucs2_czech_ci": 138,
"ucs2_danish_ci": 139,
"ucs2_lithuanian_ci": 140,
"ucs2_slovak_ci": 141,
"ucs2_spanish2_ci": 142,
"ucs2_roman_ci": 143,
"ucs2_persian_ci": 144,
"ucs2_esperanto_ci": 145,
"ucs2_hungarian_ci": 146,
"ucs2_sinhala_ci": 147,
"ucs2_german2_ci": 148,
"ucs2_croatian_ci": 149,
"ucs2_unicode_520_ci": 150,
"ucs2_vietnamese_ci": 151,
"ucs2_general_mysql500_ci": 159,
"utf32_unicode_ci": 160,
"utf32_icelandic_ci": 161,
"utf32_latvian_ci": 162,
"utf32_romanian_ci": 163,
"utf32_slovenian_ci": 164,
"utf32_polish_ci": 165,
"utf32_estonian_ci": 166,
"utf32_spanish_ci": 167,
"utf32_swedish_ci": 168,
"utf32_turkish_ci": 169,
"utf32_czech_ci": 170,
"utf32_danish_ci": 171,
"utf32_lithuanian_ci": 172,
"utf32_slovak_ci": 173,
"utf32_spanish2_ci": 174,
"utf32_roman_ci": 175,
"utf32_persian_ci": 176,
"utf32_esperanto_ci": 177,
"utf32_hungarian_ci": 178,
"utf32_sinhala_ci": 179,
"utf32_german2_ci": 180,
"utf32_croatian_ci": 181,
"utf32_unicode_520_ci": 182,
"utf32_vietnamese_ci": 183,
"utf8_unicode_ci": 192,
"utf8_icelandic_ci": 193,
"utf8_latvian_ci": 194,
"utf8_romanian_ci": 195,
"utf8_slovenian_ci": 196,
"utf8_polish_ci": 197,
"utf8_estonian_ci": 198,
"utf8_spanish_ci": 199,
"utf8_swedish_ci": 200,
"utf8_turkish_ci": 201,
"utf8_czech_ci": 202,
"utf8_danish_ci": 203,
"utf8_lithuanian_ci": 204,
"utf8_slovak_ci": 205,
"utf8_spanish2_ci": 206,
"utf8_roman_ci": 207,
"utf8_persian_ci": 208,
"utf8_esperanto_ci": 209,
"utf8_hungarian_ci": 210,
"utf8_sinhala_ci": 211,
"utf8_german2_ci": 212,
"utf8_croatian_ci": 213,
"utf8_unicode_520_ci": 214,
"utf8_vietnamese_ci": 215,
"utf8_general_mysql500_ci": 223,
"utf8mb4_unicode_ci": 224,
"utf8mb4_icelandic_ci": 225,
"utf8mb4_latvian_ci": 226,
"utf8mb4_romanian_ci": 227,
"utf8mb4_slovenian_ci": 228,
"utf8mb4_polish_ci": 229,
"utf8mb4_estonian_ci": 230,
"utf8mb4_spanish_ci": 231,
"utf8mb4_swedish_ci": 232,
"utf8mb4_turkish_ci": 233,
"utf8mb4_czech_ci": 234,
"utf8mb4_danish_ci": 235,
"utf8mb4_lithuanian_ci": 236,
"utf8mb4_slovak_ci": 237,
"utf8mb4_spanish2_ci": 238,
"utf8mb4_roman_ci": 239,
"utf8mb4_persian_ci": 240,
"utf8mb4_esperanto_ci": 241,
"utf8mb4_hungarian_ci": 242,
"utf8mb4_sinhala_ci": 243,
"utf8mb4_german2_ci": 244,
"utf8mb4_croatian_ci": 245,
"utf8mb4_unicode_520_ci": 246,
"utf8mb4_vietnamese_ci": 247,
}
// A blacklist of collations which is unsafe to interpolate parameters.
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[string]bool{
"big5_chinese_ci": true,
"sjis_japanese_ci": true,
"gbk_chinese_ci": true,
"big5_bin": true,
"gb2312_bin": true,
"gbk_bin": true,
"sjis_bin": true,
"cp932_japanese_ci": true,
"cp932_bin": true,
}

View File

@ -1,372 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"net"
"strconv"
"strings"
"time"
)
type mysqlConn struct {
buf buffer
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *Config
maxPacketAllowed int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
strict bool
}
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.Params {
switch param {
// Charset
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
}
}
}
return
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
err := mc.exec("START TRANSACTION")
if err == nil {
return &mysqlTx{mc}, err
}
return nil, err
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if mc.netConn != nil {
err = mc.writeCommandPacket(comQuit)
}
mc.cleanup()
return
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
// Makes cleanup idempotent
if mc.netConn != nil {
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
}
mc.netConn = nil
}
mc.cfg = nil
mc.buf.nc = nil
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
return nil, err
}
stmt := &mysqlStmt{
mc: mc,
}
// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
buf := mc.buf.takeCompleteBuffer()
if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return "", driver.ErrBadConn
}
buf = buf[:0]
argPos := 0
for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q
arg := args[argPos]
argPos++
if arg == nil {
buf = append(buf, "NULL"...)
continue
}
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.Loc)
v = v.Add(time.Nanosecond * 500) // To round under microsecond
year := v.Year()
year100 := year / 100
year1 := year % 100
month := v.Month()
day := v.Day()
hour := v.Hour()
minute := v.Minute()
second := v.Second()
micro := v.Nanosecond() / 1000
buf = append(buf, []byte{
'\'',
digits10[year100], digits01[year100],
digits10[year1], digits01[year1],
'-',
digits10[month], digits01[month],
'-',
digits10[day], digits01[day],
' ',
digits10[hour], digits01[hour],
':',
digits10[minute], digits01[minute],
':',
digits10[second], digits01[second],
}...)
if micro != 0 {
micro10000 := micro / 10000
micro100 := micro / 100 % 100
micro1 := micro % 100
buf = append(buf, []byte{
'.',
digits10[micro10000], digits01[micro10000],
digits10[micro100], digits01[micro100],
digits10[micro1], digits01[micro1],
}...)
}
buf = append(buf, '\'')
}
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
}
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
if len(buf)+4 > mc.maxPacketAllowed {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err != nil {
return err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil && resLen > 0 {
if err = mc.readUntilEOF(); err != nil {
return err
}
err = mc.readUntilEOF()
}
return err
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
return rows, err
}
}
return nil, err
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}

View File

@ -1,163 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const (
minProtocolVersion byte = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"
)
// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
const (
iOK byte = 0x00
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
)
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
type clientFlag uint32
const (
clientLongPassword clientFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
clientNoSchema
clientCompress
clientODBC
clientLocalFiles
clientIgnoreSpace
clientProtocol41
clientInteractive
clientSSL
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientMultiStatements
clientMultiResults
clientPSMultiResults
clientPluginAuth
clientConnectAttrs
clientPluginAuthLenEncClientData
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
)
const (
comQuit byte = iota + 1
comInitDB
comQuery
comFieldList
comCreateDB
comDropDB
comRefresh
comShutdown
comStatistics
comProcessInfo
comConnect
comProcessKill
comDebug
comPing
comTime
comDelayedInsert
comChangeUser
comBinlogDump
comTableDump
comConnectOut
comRegisterSlave
comStmtPrepare
comStmtExecute
comStmtSendLongData
comStmtClose
comStmtReset
comSetOption
comStmtFetch
)
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
const (
fieldTypeDecimal byte = iota
fieldTypeTiny
fieldTypeShort
fieldTypeLong
fieldTypeFloat
fieldTypeDouble
fieldTypeNULL
fieldTypeTimestamp
fieldTypeLongLong
fieldTypeInt24
fieldTypeDate
fieldTypeTime
fieldTypeDateTime
fieldTypeYear
fieldTypeNewDate
fieldTypeVarChar
fieldTypeBit
)
const (
fieldTypeJSON byte = iota + 0xf5
fieldTypeNewDecimal
fieldTypeEnum
fieldTypeSet
fieldTypeTinyBLOB
fieldTypeMediumBLOB
fieldTypeLongBLOB
fieldTypeBLOB
fieldTypeVarString
fieldTypeString
fieldTypeGeometry
)
type fieldFlag uint16
const (
flagNotNULL fieldFlag = 1 << iota
flagPriKey
flagUniqueKey
flagMultipleKey
flagBLOB
flagUnsigned
flagZeroFill
flagBinary
flagEnum
flagAutoIncrement
flagTimestamp
flagSet
flagUnknown1
flagUnknown2
flagUnknown3
flagUnknown4
)
// http://dev.mysql.com/doc/internals/en/status-flags.html
type statusFlag uint16
const (
statusInTrans statusFlag = 1 << iota
statusInAutocommit
statusReserved // Not in documentation
statusMoreResultsExists
statusNoGoodIndexUsed
statusNoIndexUsed
statusCursorExists
statusLastRowSent
statusDbDropped
statusNoBackslashEscapes
statusMetadataChanged
statusQueryWasSlow
statusPsOutParams
statusInTransReadonly
statusSessionStateChanged
)

View File

@ -1,167 +0,0 @@
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Package mysql provides a MySQL driver for Go's database/sql package
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql
import (
"database/sql"
"database/sql/driver"
"net"
)
// MySQLDriver is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
// DialFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDial
type DialFunc func(addr string) (net.Conn, error)
var dials map[string]DialFunc
// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
func RegisterDial(net string, dial DialFunc) {
if dials == nil {
dials = make(map[string]DialFunc)
}
dials[net] = dial
}
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
}
mc.cfg, err = ParseDSN(dsn)
if err != nil {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
mc.strict = mc.cfg.Strict
// Connect to Server
if dial, ok := dials[mc.cfg.Net]; ok {
mc.netConn, err = dial(mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
if err != nil {
mc.cleanup()
return nil, err
}
// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
mc.cleanup()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = handleAuthResult(mc, cipher); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxPacketAllowed = stringToInt(maxap) - 1
if mc.maxPacketAllowed < maxPacketSize {
mc.maxWriteSize = mc.maxPacketAllowed
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
func handleAuthResult(mc *mysqlConn, cipher []byte) error {
// Read Result Packet
err := mc.readResultOK()
if err == nil {
return nil // auth successful
}
if mc.cfg == nil {
return err // auth failed and retry not possible
}
// Retry auth if configured to do so.
if mc.cfg.AllowOldPasswords && err == ErrOldPassword {
// Retry with old authentication method. Note: there are edge cases
// where this should work but doesn't; this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
if err = mc.writeOldAuthPacket(cipher); err != nil {
return err
}
err = mc.readResultOK()
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
// Retry with clear text password for
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
if err = mc.writeClearAuthPacket(); err != nil {
return err
}
err = mc.readResultOK()
}
return err
}
func init() {
sql.Register("mysql", &MySQLDriver{})
}

View File

@ -1,513 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"strings"
"time"
)
var (
errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
// Config is a configuration parsed from a DSN string
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowOldPasswords bool // Allows the old insecure password method
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
Strict bool // Return warnings as errors
}
// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
func (cfg *Config) FormatDSN() string {
var buf bytes.Buffer
// [username[:password]@]
if len(cfg.User) > 0 {
buf.WriteString(cfg.User)
if len(cfg.Passwd) > 0 {
buf.WriteByte(':')
buf.WriteString(cfg.Passwd)
}
buf.WriteByte('@')
}
// [protocol[(address)]]
if len(cfg.Net) > 0 {
buf.WriteString(cfg.Net)
if len(cfg.Addr) > 0 {
buf.WriteByte('(')
buf.WriteString(cfg.Addr)
buf.WriteByte(')')
}
}
// /dbname
buf.WriteByte('/')
buf.WriteString(cfg.DBName)
// [?param1=value1&...&paramN=valueN]
hasParam := false
if cfg.AllowAllFiles {
hasParam = true
buf.WriteString("?allowAllFiles=true")
}
if cfg.AllowCleartextPasswords {
if hasParam {
buf.WriteString("&allowCleartextPasswords=true")
} else {
hasParam = true
buf.WriteString("?allowCleartextPasswords=true")
}
}
if cfg.AllowOldPasswords {
if hasParam {
buf.WriteString("&allowOldPasswords=true")
} else {
hasParam = true
buf.WriteString("?allowOldPasswords=true")
}
}
if cfg.ClientFoundRows {
if hasParam {
buf.WriteString("&clientFoundRows=true")
} else {
hasParam = true
buf.WriteString("?clientFoundRows=true")
}
}
if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
if hasParam {
buf.WriteString("&collation=")
} else {
hasParam = true
buf.WriteString("?collation=")
}
buf.WriteString(col)
}
if cfg.ColumnsWithAlias {
if hasParam {
buf.WriteString("&columnsWithAlias=true")
} else {
hasParam = true
buf.WriteString("?columnsWithAlias=true")
}
}
if cfg.InterpolateParams {
if hasParam {
buf.WriteString("&interpolateParams=true")
} else {
hasParam = true
buf.WriteString("?interpolateParams=true")
}
}
if cfg.Loc != time.UTC && cfg.Loc != nil {
if hasParam {
buf.WriteString("&loc=")
} else {
hasParam = true
buf.WriteString("?loc=")
}
buf.WriteString(url.QueryEscape(cfg.Loc.String()))
}
if cfg.MultiStatements {
if hasParam {
buf.WriteString("&multiStatements=true")
} else {
hasParam = true
buf.WriteString("?multiStatements=true")
}
}
if cfg.ParseTime {
if hasParam {
buf.WriteString("&parseTime=true")
} else {
hasParam = true
buf.WriteString("?parseTime=true")
}
}
if cfg.ReadTimeout > 0 {
if hasParam {
buf.WriteString("&readTimeout=")
} else {
hasParam = true
buf.WriteString("?readTimeout=")
}
buf.WriteString(cfg.ReadTimeout.String())
}
if cfg.Strict {
if hasParam {
buf.WriteString("&strict=true")
} else {
hasParam = true
buf.WriteString("?strict=true")
}
}
if cfg.Timeout > 0 {
if hasParam {
buf.WriteString("&timeout=")
} else {
hasParam = true
buf.WriteString("?timeout=")
}
buf.WriteString(cfg.Timeout.String())
}
if len(cfg.TLSConfig) > 0 {
if hasParam {
buf.WriteString("&tls=")
} else {
hasParam = true
buf.WriteString("?tls=")
}
buf.WriteString(url.QueryEscape(cfg.TLSConfig))
}
if cfg.WriteTimeout > 0 {
if hasParam {
buf.WriteString("&writeTimeout=")
} else {
hasParam = true
buf.WriteString("?writeTimeout=")
}
buf.WriteString(cfg.WriteTimeout.String())
}
// other params
if cfg.Params != nil {
for param, value := range cfg.Params {
if hasParam {
buf.WriteByte('&')
} else {
hasParam = true
buf.WriteByte('?')
}
buf.WriteString(param)
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(value))
}
}
return buf.String()
}
// ParseDSN parses the DSN string to a Config
func ParseDSN(dsn string) (cfg *Config, err error) {
// New config with some default values
cfg = &Config{
Loc: time.UTC,
Collation: defaultCollation,
}
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.Passwd = dsn[k+1 : j]
break
}
}
cfg.User = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.Addr = dsn[k+1 : i-1]
break
}
}
cfg.Net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.DBName = dsn[i+1 : j]
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
return nil, errInvalidDSNUnsafeCollation
}
// Set default network if empty
if cfg.Net == "" {
cfg.Net = "tcp"
}
// Set default address if empty
if cfg.Addr == "" {
switch cfg.Net {
case "tcp":
cfg.Addr = "127.0.0.1:3306"
case "unix":
cfg.Addr = "/tmp/mysql.sock"
default:
return nil, errors.New("default addr for network '" + cfg.Net + "' unknown")
}
}
return
}
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *Config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
// cfg params
switch value := param[1]; param[0] {
// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.AllowAllFiles, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use cleartext authentication mode (MySQL 5.5.10+)
case "allowCleartextPasswords":
var isBool bool
cfg.AllowCleartextPasswords, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.AllowOldPasswords, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.ClientFoundRows, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Collation
case "collation":
cfg.Collation = value
break
case "columnsWithAlias":
var isBool bool
cfg.ColumnsWithAlias, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Compression
case "compress":
return errors.New("compression not implemented yet")
// Enable client side placeholder substitution
case "interpolateParams":
var isBool bool
cfg.InterpolateParams, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Time Location
case "loc":
if value, err = url.QueryUnescape(value); err != nil {
return
}
cfg.Loc, err = time.LoadLocation(value)
if err != nil {
return
}
// multiple statements in one query
case "multiStatements":
var isBool bool
cfg.MultiStatements, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// time.Time parsing
case "parseTime":
var isBool bool
cfg.ParseTime, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// I/O read Timeout
case "readTimeout":
cfg.ReadTimeout, err = time.ParseDuration(value)
if err != nil {
return
}
// Strict mode
case "strict":
var isBool bool
cfg.Strict, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Dial Timeout
case "timeout":
cfg.Timeout, err = time.ParseDuration(value)
if err != nil {
return
}
// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.TLSConfig = "true"
cfg.tls = &tls.Config{}
} else {
cfg.TLSConfig = "false"
}
} else if vl := strings.ToLower(value); vl == "skip-verify" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else {
name, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid value for TLS config name: %v", err)
}
if tlsConfig, ok := tlsConfigRegister[name]; ok {
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
tlsConfig.ServerName = host
}
}
cfg.TLSConfig = name
cfg.tls = tlsConfig
} else {
return errors.New("invalid value / unknown config name: " + name)
}
}
// I/O write Timeout
case "writeTimeout":
cfg.WriteTimeout, err = time.ParseDuration(value)
if err != nil {
return
}
default:
// lazy init
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}
return
}

View File

@ -1,131 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
"os"
)
// Various errors the driver might return. Can change between driver versions.
var (
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
)
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
// Logger is used to log critical error messages.
type Logger interface {
Print(v ...interface{})
}
// SetLogger is used to set the logger for critical errors.
// The initial logger is os.Stderr.
func SetLogger(logger Logger) error {
if logger == nil {
return errors.New("logger is nil")
}
errLog = logger
return nil
}
// MySQLError is an error type which represents a single MySQL error
type MySQLError struct {
Number uint16
Message string
}
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}
// MySQLWarnings is an error type which represents a group of one or more MySQL
// warnings
type MySQLWarnings []MySQLWarning
func (mws MySQLWarnings) Error() string {
var msg string
for i, warning := range mws {
if i > 0 {
msg += "\r\n"
}
msg += fmt.Sprintf(
"%s %s: %s",
warning.Level,
warning.Code,
warning.Message,
)
}
return msg
}
// MySQLWarning is an error type which represents a single MySQL warning.
// Warnings are returned in groups only. See MySQLWarnings
type MySQLWarning struct {
Level string
Code string
Message string
}
func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", nil)
if err != nil {
return
}
var warnings = MySQLWarnings{}
var values = make([]driver.Value, 3)
for {
err = rows.Next(values)
switch err {
case nil:
warning := MySQLWarning{}
if raw, ok := values[0].([]byte); ok {
warning.Level = string(raw)
} else {
warning.Level = fmt.Sprintf("%s", values[0])
}
if raw, ok := values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}
if raw, ok := values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}
warnings = append(warnings, warning)
case io.EOF:
return warnings
default:
rows.Close()
return
}
}
}

View File

@ -1,181 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"fmt"
"io"
"os"
"strings"
"sync"
)
var (
fileRegister map[string]bool
fileRegisterLock sync.RWMutex
readerRegister map[string]func() io.Reader
readerRegisterLock sync.RWMutex
)
// RegisterLocalFile adds the given file to the file whitelist,
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
// Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true'
//
// filePath := "/home/gopher/data.csv"
// mysql.RegisterLocalFile(filePath)
// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterLocalFile(filePath string) {
fileRegisterLock.Lock()
// lazy map init
if fileRegister == nil {
fileRegister = make(map[string]bool)
}
fileRegister[strings.Trim(filePath, `"`)] = true
fileRegisterLock.Unlock()
}
// DeregisterLocalFile removes the given filepath from the whitelist.
func DeregisterLocalFile(filePath string) {
fileRegisterLock.Lock()
delete(fileRegister, strings.Trim(filePath, `"`))
fileRegisterLock.Unlock()
}
// RegisterReaderHandler registers a handler function which is used
// to receive a io.Reader.
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
// If the handler returns a io.ReadCloser Close() is called when the
// request is finished.
//
// mysql.RegisterReaderHandler("data", func() io.Reader {
// var csvReader io.Reader // Some Reader that returns CSV data
// ... // Open Reader here
// return csvReader
// })
// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterReaderHandler(name string, handler func() io.Reader) {
readerRegisterLock.Lock()
// lazy map init
if readerRegister == nil {
readerRegister = make(map[string]func() io.Reader)
}
readerRegister[name] = handler
readerRegisterLock.Unlock()
}
// DeregisterReaderHandler removes the ReaderHandler function with
// the given name from the registry.
func DeregisterReaderHandler(name string) {
readerRegisterLock.Lock()
delete(readerRegister, name)
readerRegisterLock.Unlock()
}
func deferredClose(err *error, closer io.Closer) {
closeErr := closer.Close()
if *err == nil {
*err = closeErr
}
}
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize
}
if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
// The server might return an an absolute path. See issue #355.
name = name[idx+8:]
readerRegisterLock.RLock()
handler, inMap := readerRegister[name]
readerRegisterLock.RUnlock()
if inMap {
rdr = handler()
if rdr != nil {
if cl, ok := rdr.(io.Closer); ok {
defer deferredClose(&err, cl)
}
} else {
err = fmt.Errorf("Reader '%s' is <nil>", name)
}
} else {
err = fmt.Errorf("Reader '%s' is not registered", name)
}
} else { // File
name = strings.Trim(name, `"`)
fileRegisterLock.RLock()
fr := fileRegister[name]
fileRegisterLock.RUnlock()
if mc.cfg.AllowAllFiles || fr {
var file *os.File
var fi os.FileInfo
if file, err = os.Open(name); err == nil {
defer deferredClose(&err, file)
// get file size
if fi, err = file.Stat(); err == nil {
rdr = file
if fileSize := int(fi.Size()); fileSize < packetSize {
packetSize = fileSize
}
}
}
} else {
err = fmt.Errorf("local file '%s' is not registered", name)
}
}
// send content packets
if err == nil {
data := make([]byte, 4+packetSize)
var n int
for err == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
return ioErr
}
}
}
if err == io.EOF {
err = nil
}
}
// send empty packet (termination)
if data == nil {
data = make([]byte, 4)
}
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
return ioErr
}
// read OK packet
if err == nil {
return mc.readResultOK()
}
mc.readPacket()
return err
}

View File

@ -1,22 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlResult struct {
affectedRows int64
insertId int64
}
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil
}

View File

@ -1,112 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"io"
)
type mysqlField struct {
tableName string
name string
flags fieldFlag
fieldType byte
decimals byte
}
type mysqlRows struct {
mc *mysqlConn
columns []mysqlField
}
type binaryRows struct {
mysqlRows
}
type textRows struct {
mysqlRows
}
type emptyRows struct{}
func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
for i := range columns {
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
columns[i] = tableName + "." + rows.columns[i].name
} else {
columns[i] = rows.columns[i].name
}
}
} else {
for i := range columns {
columns[i] = rows.columns[i].name
}
}
return columns
}
func (rows *mysqlRows) Close() error {
mc := rows.mc
if mc == nil {
return nil
}
if mc.netConn == nil {
return ErrInvalidConn
}
// Remove unread packets from stream
err := mc.readUntilEOF()
if err == nil {
if err = mc.discardResults(); err != nil {
return err
}
}
rows.mc = nil
return err
}
func (rows *binaryRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return ErrInvalidConn
}
// Fetch next row from stream
return rows.readRow(dest)
}
return io.EOF
}
func (rows *textRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return ErrInvalidConn
}
// Fetch next row from stream
return rows.readRow(dest)
}
return io.EOF
}
func (rows emptyRows) Columns() []string {
return nil
}
func (rows emptyRows) Close() error {
return nil
}
func (rows emptyRows) Next(dest []driver.Value) error {
return io.EOF
}

View File

@ -1,150 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
)
type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
columns []mysqlField // cached from the first query
}
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
stmt.mc = nil
return err
}
func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
return converter{}
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
mc.affectedRows = 0
mc.insertId = 0
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
if resLen > 0 {
// Columns
err = mc.readUntilEOF()
if err != nil {
return nil, err
}
// Rows
err = mc.readUntilEOF()
}
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, nil
}
}
return nil, err
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
if resLen > 0 {
rows.mc = mc
// Columns
// If not cached, read them and cache them
if stmt.columns == nil {
rows.columns, err = mc.readColumns(resLen)
stmt.columns = rows.columns
} else {
rows.columns = stmt.columns
err = mc.readUntilEOF()
}
}
return rows, err
}
type converter struct{}
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
}
return c.ConvertValue(rv.Elem().Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return strconv.FormatUint(u64, 10), nil
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}

View File

@ -1,31 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlTx struct {
mc *mysqlConn
}
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return ErrInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return ErrInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}

View File

@ -1,740 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/sha1"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"strings"
"time"
)
var (
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
)
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
// rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
// if err != nil {
// log.Fatal(err)
// }
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
// log.Fatal("Failed to append PEM.")
// }
// clientCert := make([]tls.Certificate, 0, 1)
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
// if err != nil {
// log.Fatal(err)
// }
// clientCert = append(clientCert, certs)
// mysql.RegisterTLSConfig("custom", &tls.Config{
// RootCAs: rootCertPool,
// Certificates: clientCert,
// })
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
return fmt.Errorf("key '%s' is reserved", key)
}
if tlsConfigRegister == nil {
tlsConfigRegister = make(map[string]*tls.Config)
}
tlsConfigRegister[key] = config
return nil
}
// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
if tlsConfigRegister != nil {
delete(tlsConfigRegister, key)
}
}
// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}
/******************************************************************************
* Authentication *
******************************************************************************/
// Encrypt password using 4.1+ method
func scramblePassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
// stage1Hash = SHA1(password)
crypt := sha1.New()
crypt.Write(password)
stage1 := crypt.Sum(nil)
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
// inner Hash
crypt.Reset()
crypt.Write(stage1)
hash := crypt.Sum(nil)
// outer Hash
crypt.Reset()
crypt.Write(scramble)
crypt.Write(hash)
scramble = crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
// Encrypt password using pre 4.1 (old password) method
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
seed1, seed2 uint32
}
const myRndMaxVal = 0x3FFFFFFF
// Pseudo random number generator
func newMyRnd(seed1, seed2 uint32) *myRnd {
return &myRnd{
seed1: seed1 % myRndMaxVal,
seed2: seed2 % myRndMaxVal,
}
}
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
}
// Generate binary hash from byte string using insecure pre 4.1 method
func pwHash(password []byte) (result [2]uint32) {
var add uint32 = 7
var tmp uint32
result[0] = 1345345333
result[1] = 0x12345671
for _, c := range password {
// skip spaces and tabs in password
if c == ' ' || c == '\t' {
continue
}
tmp = uint32(c)
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
result[1] += (result[1] << 8) ^ result[0]
add += tmp
}
// Remove sign bit (1<<31)-1)
result[0] &= 0x7FFFFFFF
result[1] &= 0x7FFFFFFF
return
}
// Encrypt password using insecure pre 4.1 method
func scrambleOldPassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
scramble = scramble[:8]
hashPw := pwHash(password)
hashSc := pwHash(scramble)
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
var out [8]byte
for i := range out {
out[i] = r.NextByte() + 64
}
mask := r.NextByte()
for i := range out {
out[i] ^= mask
}
return out[:]
}
/******************************************************************************
* Time related utils *
******************************************************************************/
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
}
switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
}
nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
base := "0000-00-00 00:00:00.0000000"
switch len(str) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
}
t, err = time.Parse(timeFormat[:len(str)], str)
default:
err = fmt.Errorf("invalid time string: %s", str)
return
}
// Adjust location
if err == nil && loc != time.UTC {
y, mo, d := t.Date()
h, mi, s := t.Clock()
t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
}
return
}
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
switch num {
case 0:
return time.Time{}, nil
case 4:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
0, 0, 0, 0,
loc,
), nil
case 7:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
0,
loc,
), nil
case 11:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
loc,
), nil
}
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
}
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
// if the DATE or DATETIME has the zero value.
// It must never be changed.
// The current behavior depends on database/sql copying the result.
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) {
// length expects the deterministic length of the zero value,
// negative time and 100+ hours are automatically added if needed
if len(src) == 0 {
if justTime {
return zeroDateTime[11 : 11+length], nil
}
return zeroDateTime[:length], nil
}
var dst []byte // return value
var pt, p1, p2, p3 byte // current digit pair
var zOffs byte // offset of value in zeroDateTime
if justTime {
switch length {
case
8, // time (can be up to 10 when negative and 100+ hours)
10, 11, 12, 13, 14, 15: // time with fractional seconds
default:
return nil, fmt.Errorf("illegal TIME length %d", length)
}
switch len(src) {
case 8, 12:
default:
return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
}
// +2 to enable negative time and 100+ hours
dst = make([]byte, 0, length+2)
if src[0] == 1 {
dst = append(dst, '-')
}
if src[1] != 0 {
hour := uint16(src[1])*24 + uint16(src[5])
pt = byte(hour / 100)
p1 = byte(hour - 100*uint16(pt))
dst = append(dst, digits01[pt])
} else {
p1 = src[5]
}
zOffs = 11
src = src[6:]
} else {
switch length {
case 10, 19, 21, 22, 23, 24, 25, 26:
default:
t := "DATE"
if length > 10 {
t += "TIME"
}
return nil, fmt.Errorf("illegal %s length %d", t, length)
}
switch len(src) {
case 4, 7, 11:
default:
t := "DATE"
if length > 10 {
t += "TIME"
}
return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
}
dst = make([]byte, 0, length)
// start with the date
year := binary.LittleEndian.Uint16(src[:2])
pt = byte(year / 100)
p1 = byte(year - 100*uint16(pt))
p2, p3 = src[2], src[3]
dst = append(dst,
digits10[pt], digits01[pt],
digits10[p1], digits01[p1], '-',
digits10[p2], digits01[p2], '-',
digits10[p3], digits01[p3],
)
if length == 10 {
return dst, nil
}
if len(src) == 4 {
return append(dst, zeroDateTime[10:length]...), nil
}
dst = append(dst, ' ')
p1 = src[4] // hour
src = src[5:]
}
// p1 is 2-digit hour, src is after hour
p2, p3 = src[0], src[1]
dst = append(dst,
digits10[p1], digits01[p1], ':',
digits10[p2], digits01[p2], ':',
digits10[p3], digits01[p3],
)
if length <= byte(len(dst)) {
return dst, nil
}
src = src[2:]
if len(src) == 0 {
return append(dst, zeroDateTime[19:zOffs+length]...), nil
}
microsecs := binary.LittleEndian.Uint32(src[:4])
p1 = byte(microsecs / 10000)
microsecs -= 10000 * uint32(p1)
p2 = byte(microsecs / 100)
microsecs -= 100 * uint32(p2)
p3 = byte(microsecs)
switch decimals := zOffs + length - 20; decimals {
default:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3], digits01[p3],
), nil
case 1:
return append(dst, '.',
digits10[p1],
), nil
case 2:
return append(dst, '.',
digits10[p1], digits01[p1],
), nil
case 3:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2],
), nil
case 4:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
), nil
case 5:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3],
), nil
}
}
/******************************************************************************
* Convert from and to bytes *
******************************************************************************/
func uint64ToBytes(n uint64) []byte {
return []byte{
byte(n),
byte(n >> 8),
byte(n >> 16),
byte(n >> 24),
byte(n >> 32),
byte(n >> 40),
byte(n >> 48),
byte(n >> 56),
}
}
func uint64ToString(n uint64) []byte {
var a [20]byte
i := 20
// U+0030 = 0
// ...
// U+0039 = 9
var q uint64
for n >= 10 {
i--
q = n / 10
a[i] = uint8(n-q*10) + 0x30
n = q
}
i--
a[i] = uint8(n) + 0x30
return a[i:]
}
// treats string value as unsigned integer representation
func stringToInt(b []byte) int {
val := 0
for i := range b {
val *= 10
val += int(b[i] - 0x30)
}
return val
}
// returns the string read as a bytes slice, wheter the value is NULL,
// the number of bytes read and an error, in case the string is longer than
// the input slice
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
// Get length
num, isNull, n := readLengthEncodedInteger(b)
if num < 1 {
return b[n:n], isNull, n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return b[n-int(num) : n], false, n, nil
}
return nil, false, n, io.EOF
}
// returns the number of bytes skipped and an error, in case the string is
// longer than the input slice
func skipLengthEncodedString(b []byte) (int, error) {
// Get length
num, _, n := readLengthEncodedInteger(b)
if num < 1 {
return n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return n, nil
}
return n, io.EOF
}
// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// See issue #349
if len(b) == 0 {
return 0, true, 1
}
switch b[0] {
// 251: NULL
case 0xfb:
return 0, true, 1
// 252: value of following 2
case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3
// 253: value of following 3
case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
// 254: value of following 8
case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
}
// 0-250: value of first byte
return uint64(b[0]), false, 1
}
// encodes a uint64 value and appends it to the given bytes slice
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
switch {
case n <= 250:
return append(b, byte(n))
case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8))
case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
}
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
// If cap(buf) is not enough, reallocate new buffer.
func reserveBuffer(buf []byte, appendSize int) []byte {
newSize := len(buf) + appendSize
if cap(buf) < newSize {
// Grow buffer exponentially
newBuf := make([]byte, len(buf)*2+appendSize)
copy(newBuf, buf)
buf = newBuf
}
return buf[:newSize]
}
// escapeBytesBackslash escapes []byte with backslashes (\)
// This escapes the contents of a string (provided as []byte) by adding backslashes before special
// characters, and turning others into specific escape sequences, such as
// turning newlines into \n and null bytes into \0.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
func escapeBytesBackslash(buf, v []byte) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for _, c := range v {
switch c {
case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0'
pos += 2
case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n'
pos += 2
case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r'
pos += 2
case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z'
pos += 2
case '\'':
buf[pos] = '\\'
buf[pos+1] = '\''
pos += 2
case '"':
buf[pos] = '\\'
buf[pos+1] = '"'
pos += 2
case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\'
pos += 2
default:
buf[pos] = c
pos++
}
}
return buf[:pos]
}
// escapeStringBackslash is similar to escapeBytesBackslash but for string.
func escapeStringBackslash(buf []byte, v string) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for i := 0; i < len(v); i++ {
c := v[i]
switch c {
case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0'
pos += 2
case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n'
pos += 2
case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r'
pos += 2
case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z'
pos += 2
case '\'':
buf[pos] = '\\'
buf[pos+1] = '\''
pos += 2
case '"':
buf[pos] = '\\'
buf[pos+1] = '"'
pos += 2
case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\'
pos += 2
default:
buf[pos] = c
pos++
}
}
return buf[:pos]
}
// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
// This escapes the contents of a string by doubling up any apostrophes that
// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
// effect on the server.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
func escapeBytesQuotes(buf, v []byte) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for _, c := range v {
if c == '\'' {
buf[pos] = '\''
buf[pos+1] = '\''
pos += 2
} else {
buf[pos] = c
pos++
}
}
return buf[:pos]
}
// escapeStringQuotes is similar to escapeBytesQuotes but for string.
func escapeStringQuotes(buf []byte, v string) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for i := 0; i < len(v); i++ {
c := v[i]
if c == '\'' {
buf[pos] = '\''
buf[pos+1] = '\''
pos += 2
} else {
buf[pos] = c
pos++
}
}
return buf[:pos]
}

View File

@ -1,23 +0,0 @@
Copyright (c) 2013, Jason Moiron
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,186 +0,0 @@
package sqlx
import (
"bytes"
"errors"
"reflect"
"strconv"
"strings"
"github.com/jmoiron/sqlx/reflectx"
)
// Bindvar types supported by Rebind, BindMap and BindStruct.
const (
UNKNOWN = iota
QUESTION
DOLLAR
NAMED
)
// BindType returns the bindtype for a given database given a drivername.
func BindType(driverName string) int {
switch driverName {
case "postgres", "pgx":
return DOLLAR
case "mysql":
return QUESTION
case "sqlite3":
return QUESTION
case "oci8":
return NAMED
}
return UNKNOWN
}
// FIXME: this should be able to be tolerant of escaped ?'s in queries without
// losing much speed, and should be to avoid confusion.
// Rebind a query from the default bindtype (QUESTION) to the target bindtype.
func Rebind(bindType int, query string) string {
switch bindType {
case QUESTION, UNKNOWN:
return query
}
qb := []byte(query)
// Add space enough for 10 params before we have to allocate
rqb := make([]byte, 0, len(qb)+10)
j := 1
for _, b := range qb {
if b == '?' {
switch bindType {
case DOLLAR:
rqb = append(rqb, '$')
case NAMED:
rqb = append(rqb, ':', 'a', 'r', 'g')
}
for _, b := range strconv.Itoa(j) {
rqb = append(rqb, byte(b))
}
j++
} else {
rqb = append(rqb, b)
}
}
return string(rqb)
}
// Experimental implementation of Rebind which uses a bytes.Buffer. The code is
// much simpler and should be more resistant to odd unicode, but it is twice as
// slow. Kept here for benchmarking purposes and to possibly replace Rebind if
// problems arise with its somewhat naive handling of unicode.
func rebindBuff(bindType int, query string) string {
if bindType != DOLLAR {
return query
}
b := make([]byte, 0, len(query))
rqb := bytes.NewBuffer(b)
j := 1
for _, r := range query {
if r == '?' {
rqb.WriteRune('$')
rqb.WriteString(strconv.Itoa(j))
j++
} else {
rqb.WriteRune(r)
}
}
return rqb.String()
}
// In expands slice values in args, returning the modified query string
// and a new arg list that can be executed by a database. The `query` should
// use the `?` bindVar. The return value uses the `?` bindVar.
func In(query string, args ...interface{}) (string, []interface{}, error) {
// argMeta stores reflect.Value and length for slices and
// the value itself for non-slice arguments
type argMeta struct {
v reflect.Value
i interface{}
length int
}
var flatArgsCount int
var anySlices bool
meta := make([]argMeta, len(args))
for i, arg := range args {
v := reflect.ValueOf(arg)
t := reflectx.Deref(v.Type())
if t.Kind() == reflect.Slice {
meta[i].length = v.Len()
meta[i].v = v
anySlices = true
flatArgsCount += meta[i].length
if meta[i].length == 0 {
return "", nil, errors.New("empty slice passed to 'in' query")
}
} else {
meta[i].i = arg
flatArgsCount++
}
}
// don't do any parsing if there aren't any slices; note that this means
// some errors that we might have caught below will not be returned.
if !anySlices {
return query, args, nil
}
newArgs := make([]interface{}, 0, flatArgsCount)
var arg, offset int
var buf bytes.Buffer
for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') {
if arg >= len(meta) {
// if an argument wasn't passed, lets return an error; this is
// not actually how database/sql Exec/Query works, but since we are
// creating an argument list programmatically, we want to be able
// to catch these programmer errors earlier.
return "", nil, errors.New("number of bindVars exceeds arguments")
}
argMeta := meta[arg]
arg++
// not a slice, continue.
// our questionmark will either be written before the next expansion
// of a slice or after the loop when writing the rest of the query
if argMeta.length == 0 {
offset = offset + i + 1
newArgs = append(newArgs, argMeta.i)
continue
}
// write everything up to and including our ? character
buf.WriteString(query[:offset+i+1])
newArgs = append(newArgs, argMeta.v.Index(0).Interface())
for si := 1; si < argMeta.length; si++ {
buf.WriteString(", ?")
newArgs = append(newArgs, argMeta.v.Index(si).Interface())
}
// slice the query and reset the offset. this avoids some bookkeeping for
// the write after the loop
query = query[offset+i+1:]
offset = 0
}
buf.WriteString(query)
if arg < len(meta) {
return "", nil, errors.New("number of bindVars less than number arguments")
}
return buf.String(), newArgs, nil
}

View File

@ -1,12 +0,0 @@
// Package sqlx provides general purpose extensions to database/sql.
//
// It is intended to seamlessly wrap database/sql and provide convenience
// methods which are useful in the development of database driven applications.
// None of the underlying database/sql methods are changed. Instead all extended
// behavior is implemented through new methods defined on wrapper types.
//
// Additions include scanning into structs, named query support, rebinding
// queries for different drivers, convenient shorthands for common error handling
// and more.
//
package sqlx

View File

@ -1,336 +0,0 @@
package sqlx
// Named Query Support
//
// * BindMap - bind query bindvars to map/struct args
// * NamedExec, NamedQuery - named query w/ struct or map
// * NamedStmt - a pre-compiled named query which is a prepared statement
//
// Internal Interfaces:
//
// * compileNamedQuery - rebind a named query, returning a query and list of names
// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
//
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
"unicode"
"github.com/jmoiron/sqlx/reflectx"
)
// NamedStmt is a prepared statement that executes named queries. Prepare it
// how you would execute a NamedQuery, but pass in a struct or map when executing.
type NamedStmt struct {
Params []string
QueryString string
Stmt *Stmt
}
// Close closes the named statement.
func (n *NamedStmt) Close() error {
return n.Stmt.Close()
}
// Exec executes a named statement using the struct passed.
func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return *new(sql.Result), err
}
return n.Stmt.Exec(args...)
}
// Query executes a named statement using the struct argument, returning rows.
func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return nil, err
}
return n.Stmt.Query(args...)
}
// QueryRow executes a named statement against the database. Because sqlx cannot
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
// returns a *sqlx.Row instead.
func (n *NamedStmt) QueryRow(arg interface{}) *Row {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return &Row{err: err}
}
return n.Stmt.QueryRowx(args...)
}
// MustExec execs a NamedStmt, panicing on error
func (n *NamedStmt) MustExec(arg interface{}) sql.Result {
res, err := n.Exec(arg)
if err != nil {
panic(err)
}
return res
}
// Queryx using this NamedStmt
func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
r, err := n.Query(arg)
if err != nil {
return nil, err
}
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
}
// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is
// an alias for QueryRow.
func (n *NamedStmt) QueryRowx(arg interface{}) *Row {
return n.QueryRow(arg)
}
// Select using this NamedStmt
func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
rows, err := n.Queryx(arg)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// Get using this NamedStmt
func (n *NamedStmt) Get(dest interface{}, arg interface{}) error {
r := n.QueryRowx(arg)
return r.scanAny(dest, false)
}
// Unsafe creates an unsafe version of the NamedStmt
func (n *NamedStmt) Unsafe() *NamedStmt {
r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString}
r.Stmt.unsafe = true
return r
}
// A union interface of preparer and binder, required to be able to prepare
// named statements (as the bindtype must be determined).
type namedPreparer interface {
Preparer
binder
}
func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) {
bindType := BindType(p.DriverName())
q, args, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return nil, err
}
stmt, err := Preparex(p, q)
if err != nil {
return nil, err
}
return &NamedStmt{
QueryString: q,
Params: args,
Stmt: stmt,
}, nil
}
func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMapArgs(names, maparg)
}
return bindArgs(names, arg, m)
}
// private interface to generate a list of interfaces from a given struct
// type, given a list of names to pull out of the struct. Used by public
// BindStruct interface.
func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names))
// grab the indirected value of arg
v := reflect.ValueOf(arg)
for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; {
v = v.Elem()
}
fields := m.TraversalsByName(v.Type(), names)
for i, t := range fields {
if len(t) == 0 {
return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg)
}
val := reflectx.FieldByIndexesReadOnly(v, t)
arglist = append(arglist, val.Interface())
}
return arglist, nil
}
// like bindArgs, but for maps.
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names))
for _, name := range names {
val, ok := arg[name]
if !ok {
return arglist, fmt.Errorf("could not find name %s in %#v", name, arg)
}
arglist = append(arglist, val)
}
return arglist, nil
}
// bindStruct binds a named parameter query with fields from a struct argument.
// The rules for binding field names to parameter names follow the same
// conventions as for StructScan, including obeying the `db` struct tags.
func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}
arglist, err := bindArgs(names, arg, m)
if err != nil {
return "", []interface{}{}, err
}
return bound, arglist, nil
}
// bindMap binds a named parameter query with a map of arguments.
func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}
arglist, err := bindMapArgs(names, args)
return bound, arglist, err
}
// -- Compilation of Named Queries
// Allow digits and letters in bind params; additionally runes are
// checked against underscores, meaning that bind params can have be
// alphanumeric with underscores. Mind the difference between unicode
// digits and numbers, where '5' is a digit but '五' is not.
var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}
// FIXME: this function isn't safe for unicode named params, as a failing test
// can testify. This is not a regression but a failure of the original code
// as well. It should be modified to range over runes in a string rather than
// bytes, even though this is less convenient and slower. Hopefully the
// addition of the prepared NamedStmt (which will only do this once) will make
// up for the slightly slower ad-hoc NamedExec/NamedQuery.
// compile a NamedQuery into an unbound query (using the '?' bindvar) and
// a list of names.
func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
names = make([]string, 0, 10)
rebound := make([]byte, 0, len(qs))
inName := false
last := len(qs) - 1
currentVar := 1
name := make([]byte, 0, 10)
for i, b := range qs {
// a ':' while we're in a name is an error
if b == ':' {
// if this is the second ':' in a '::' escape sequence, append a ':'
if inName && i > 0 && qs[i-1] == ':' {
rebound = append(rebound, ':')
inName = false
continue
} else if inName {
err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
return query, names, err
}
inName = true
name = []byte{}
// if we're in a name, and this is an allowed character, continue
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_') && i != last {
// append the byte to the name if we are in a name and not on the last byte
name = append(name, b)
// if we're in a name and it's not an allowed character, the name is done
} else if inName {
inName = false
// if this is the final byte of the string and it is part of the name, then
// make sure to add it to the name
if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
name = append(name, b)
}
// add the string representation to the names list
names = append(names, string(name))
// add a proper bindvar for the bindType
switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
rebound = append(rebound, ':')
rebound = append(rebound, name...)
case QUESTION, UNKNOWN:
rebound = append(rebound, '?')
case DOLLAR:
rebound = append(rebound, '$')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
}
currentVar++
}
// add this byte to string unless it was not part of the name
if i != last {
rebound = append(rebound, b)
} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
rebound = append(rebound, b)
}
} else {
// this is a normal byte and should just go onto the rebound query
rebound = append(rebound, b)
}
}
return string(rebound), names, err
}
// BindNamed binds a struct or a map to a query with named parameters.
// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future.
func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(bindType, query, arg, mapper())
}
// Named takes a query using named parameters and an argument and
// returns a new query with a list of args that can be executed by
// a database. The return value uses the `?` bindvar.
func Named(query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(QUESTION, query, arg, mapper())
}
func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMap(bindType, query, maparg)
}
return bindStruct(bindType, query, arg, m)
}
// NamedQuery binds a named query and then runs Query on the result using the
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
// map[string]interface{} types.
func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.Queryx(q, args...)
}
// NamedExec uses BindStruct to get a query executable by the driver and
// then runs Exec on the result. Returns an error from the binding
// or the query excution itself.
func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.Exec(q, args...)
}

View File

@ -1,371 +0,0 @@
// Package reflectx implements extensions to the standard reflect lib suitable
// for implementing marshaling and unmarshaling packages. The main Mapper type
// allows for Go-compatible named atribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to
// customize field names.
//
package reflectx
import (
"fmt"
"reflect"
"runtime"
"strings"
"sync"
)
// A FieldInfo is a collection of metadata about a struct field.
type FieldInfo struct {
Index []int
Path string
Field reflect.StructField
Zero reflect.Value
Name string
Options map[string]string
Embedded bool
Children []*FieldInfo
Parent *FieldInfo
}
// A StructMap is an index of field metadata for a struct.
type StructMap struct {
Tree *FieldInfo
Index []*FieldInfo
Paths map[string]*FieldInfo
Names map[string]*FieldInfo
}
// GetByPath returns a *FieldInfo for a given string path.
func (f StructMap) GetByPath(path string) *FieldInfo {
return f.Paths[path]
}
// GetByTraversal returns a *FieldInfo for a given integer path. It is
// analagous to reflect.FieldByIndex.
func (f StructMap) GetByTraversal(index []int) *FieldInfo {
if len(index) == 0 {
return nil
}
tree := f.Tree
for _, i := range index {
if i >= len(tree.Children) || tree.Children[i] == nil {
return nil
}
tree = tree.Children[i]
}
return tree
}
// Mapper is a general purpose mapper of names to struct fields. A Mapper
// behaves like most marshallers, optionally obeying a field tag for name
// mapping and a function to provide a basic mapping of fields to names.
type Mapper struct {
cache map[reflect.Type]*StructMap
tagName string
tagMapFunc func(string) string
mapFunc func(string) string
mutex sync.Mutex
}
// NewMapper returns a new mapper which optionally obeys the field tag given
// by tagName. If tagName is the empty string, it is ignored.
func NewMapper(tagName string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
}
}
// NewMapperTagFunc returns a new mapper which contains a mapper for field names
// AND a mapper for tag values. This is useful for tags like json which can
// have values like "name,omitempty".
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
mapFunc: mapFunc,
tagMapFunc: tagMapFunc,
}
}
// NewMapperFunc returns a new mapper which optionally obeys a field tag and
// a struct field name mapper func given by f. Tags will take precedence, but
// for any other field, the mapped name will be f(field.Name)
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
mapFunc: f,
}
}
// TypeMap returns a mapping of field strings to int slices representing
// the traversal down the struct to reach the field.
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
m.mutex.Lock()
mapping, ok := m.cache[t]
if !ok {
mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
m.cache[t] = mapping
}
m.mutex.Unlock()
return mapping
}
// FieldMap returns the mapper's mapping of field names to reflect values. Panics
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
r := map[string]reflect.Value{}
tm := m.TypeMap(v.Type())
for tagName, fi := range tm.Names {
r[tagName] = FieldByIndexes(v, fi.Index)
}
return r
}
// FieldByName returns a field by the its mapped name as a reflect.Value.
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
// Returns zero Value if the name is not found.
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
tm := m.TypeMap(v.Type())
fi, ok := tm.Names[name]
if !ok {
return v
}
return FieldByIndexes(v, fi.Index)
}
// FieldsByName returns a slice of values corresponding to the slice of names
// for the value. Panics if v's Kind is not Struct or v is not Indirectable
// to a struct Kind. Returns zero Value for each name not found.
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
tm := m.TypeMap(v.Type())
vals := make([]reflect.Value, 0, len(names))
for _, name := range names {
fi, ok := tm.Names[name]
if !ok {
vals = append(vals, *new(reflect.Value))
} else {
vals = append(vals, FieldByIndexes(v, fi.Index))
}
}
return vals
}
// TraversalsByName returns a slice of int slices which represent the struct
// traversals for each mapped name. Panics if t is not a struct or Indirectable
// to a struct. Returns empty int slice for each name not found.
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
t = Deref(t)
mustBe(t, reflect.Struct)
tm := m.TypeMap(t)
r := make([][]int, 0, len(names))
for _, name := range names {
fi, ok := tm.Names[name]
if !ok {
r = append(r, []int{})
} else {
r = append(r, fi.Index)
}
}
return r
}
// FieldByIndexes returns a value for a particular struct traversal.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
// if this is a pointer, it's possible it is nil
if v.Kind() == reflect.Ptr && v.IsNil() {
alloc := reflect.New(Deref(v.Type()))
v.Set(alloc)
}
if v.Kind() == reflect.Map && v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}
}
return v
}
// FieldByIndexesReadOnly returns a value for a particular struct traversal,
// but is not concerned with allocating nil pointers because the value is
// going to be used for reading and not setting.
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
}
return v
}
// Deref is Indirect for reflect.Types
func Deref(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
// -- helpers & utilities --
type kinder interface {
Kind() reflect.Kind
}
// mustBe checks a value against a kind, panicing with a reflect.ValueError
// if the kind isn't that which is required.
func mustBe(v kinder, expected reflect.Kind) {
k := v.Kind()
if k != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: k})
}
}
// methodName is returns the caller of the function calling methodName
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}
type typeQueue struct {
t reflect.Type
fi *FieldInfo
pp string // Parent path
}
// A copying append that creates a new slice each time.
func apnd(is []int, i int) []int {
x := make([]int, len(is)+1)
for p, n := range is {
x[p] = n
}
x[len(x)-1] = i
return x
}
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields.
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap {
m := []*FieldInfo{}
root := &FieldInfo{}
queue := []typeQueue{}
queue = append(queue, typeQueue{Deref(t), root, ""})
for len(queue) != 0 {
// pop the first item off of the queue
tq := queue[0]
queue = queue[1:]
nChildren := 0
if tq.t.Kind() == reflect.Struct {
nChildren = tq.t.NumField()
}
tq.fi.Children = make([]*FieldInfo, nChildren)
// iterate through all of its fields
for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
f := tq.t.Field(fieldPos)
fi := FieldInfo{}
fi.Field = f
fi.Zero = reflect.New(f.Type).Elem()
fi.Options = map[string]string{}
var tag, name string
if tagName != "" && strings.Contains(string(f.Tag), tagName+":") {
tag = f.Tag.Get(tagName)
name = tag
} else {
if mapFunc != nil {
name = mapFunc(f.Name)
}
}
parts := strings.Split(name, ",")
if len(parts) > 1 {
name = parts[0]
for _, opt := range parts[1:] {
kv := strings.Split(opt, "=")
if len(kv) > 1 {
fi.Options[kv[0]] = kv[1]
} else {
fi.Options[kv[0]] = ""
}
}
}
if tagMapFunc != nil {
tag = tagMapFunc(tag)
}
fi.Name = name
if tq.pp == "" || (tq.pp == "" && tag == "") {
fi.Path = fi.Name
} else {
fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name)
}
// if the name is "-", disabled via a tag, skip it
if name == "-" {
continue
}
// skip unexported fields
if len(f.PkgPath) != 0 {
continue
}
// bfs search of anonymous embedded structs
if f.Anonymous {
pp := tq.pp
if tag != "" {
pp = fi.Path
}
fi.Embedded = true
fi.Index = apnd(tq.fi.Index, fieldPos)
nChildren := 0
ft := Deref(f.Type)
if ft.Kind() == reflect.Struct {
nChildren = ft.NumField()
}
fi.Children = make([]*FieldInfo, nChildren)
queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
}
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Parent = tq.fi
tq.fi.Children[fieldPos] = &fi
m = append(m, &fi)
}
}
flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
for _, fi := range flds.Index {
flds.Paths[fi.Path] = fi
if fi.Name != "" && !fi.Embedded {
flds.Names[fi.Path] = fi
}
}
return flds
}

View File

@ -1,992 +0,0 @@
package sqlx
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
"strings"
"github.com/jmoiron/sqlx/reflectx"
)
// Although the NameMapper is convenient, in practice it should not
// be relied on except for application code. If you are writing a library
// that uses sqlx, you should be aware that the name mappings you expect
// can be overridded by your user's application.
// NameMapper is used to map column names to struct field names. By default,
// it uses strings.ToLower to lowercase struct field names. It can be set
// to whatever you want, but it is encouraged to be set before sqlx is used
// as name-to-field mappings are cached after first use on a type.
var NameMapper = strings.ToLower
var origMapper = reflect.ValueOf(NameMapper)
// Rather than creating on init, this is created when necessary so that
// importers have time to customize the NameMapper.
var mpr *reflectx.Mapper
// mapper returns a valid mapper using the configured NameMapper func.
func mapper() *reflectx.Mapper {
if mpr == nil {
mpr = reflectx.NewMapperFunc("db", NameMapper)
} else if origMapper != reflect.ValueOf(NameMapper) {
// if NameMapper has changed, create a new mapper
mpr = reflectx.NewMapperFunc("db", NameMapper)
origMapper = reflect.ValueOf(NameMapper)
}
return mpr
}
// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. Something is scannable if:
// * it is not a struct
// * it implements sql.Scanner
// * it has no exported fields
func isScannable(t reflect.Type) bool {
if reflect.PtrTo(t).Implements(_scannerInterface) {
return true
}
if t.Kind() != reflect.Struct {
return true
}
// it's not important that we use the right mapper for this particular object,
// we're only concerned on how many exported fields this struct has
m := mapper()
if len(m.TypeMap(t).Index) == 0 {
return true
}
return false
}
// ColScanner is an interface used by MapScan and SliceScan
type ColScanner interface {
Columns() ([]string, error)
Scan(dest ...interface{}) error
Err() error
}
// Queryer is an interface used by Get and Select
type Queryer interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Queryx(query string, args ...interface{}) (*Rows, error)
QueryRowx(query string, args ...interface{}) *Row
}
// Execer is an interface used by MustExec and LoadFile
type Execer interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}
// Binder is an interface for something which can bind queries (Tx, DB)
type binder interface {
DriverName() string
Rebind(string) string
BindNamed(string, interface{}) (string, []interface{}, error)
}
// Ext is a union interface which can bind, query, and exec, used by
// NamedQuery and NamedExec.
type Ext interface {
binder
Queryer
Execer
}
// Preparer is an interface used by Preparex.
type Preparer interface {
Prepare(query string) (*sql.Stmt, error)
}
// determine if any of our extensions are unsafe
func isUnsafe(i interface{}) bool {
switch v := i.(type) {
case Row:
return v.unsafe
case *Row:
return v.unsafe
case Rows:
return v.unsafe
case *Rows:
return v.unsafe
case NamedStmt:
return v.Stmt.unsafe
case *NamedStmt:
return v.Stmt.unsafe
case Stmt:
return v.unsafe
case *Stmt:
return v.unsafe
case qStmt:
return v.unsafe
case *qStmt:
return v.unsafe
case DB:
return v.unsafe
case *DB:
return v.unsafe
case Tx:
return v.unsafe
case *Tx:
return v.unsafe
case sql.Rows, *sql.Rows:
return false
default:
return false
}
}
func mapperFor(i interface{}) *reflectx.Mapper {
switch i.(type) {
case DB:
return i.(DB).Mapper
case *DB:
return i.(*DB).Mapper
case Tx:
return i.(Tx).Mapper
case *Tx:
return i.(*Tx).Mapper
default:
return mapper()
}
}
var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// Row is a reimplementation of sql.Row in order to gain access to the underlying
// sql.Rows.Columns() data, necessary for StructScan.
type Row struct {
err error
unsafe bool
rows *sql.Rows
Mapper *reflectx.Mapper
}
// Scan is a fixed implementation of sql.Row.Scan, which does not discard the
// underlying error from the internal rows object if it exists.
func (r *Row) Scan(dest ...interface{}) error {
if r.err != nil {
return r.err
}
// TODO(bradfitz): for now we need to defensively clone all
// []byte that the driver returned (not permitting
// *RawBytes in Rows.Scan), since we're about to close
// the Rows in our defer, when we return from this function.
// the contract with the driver.Next(...) interface is that it
// can return slices into read-only temporary memory that's
// only valid until the next Scan/Close. But the TODO is that
// for a lot of drivers, this copy will be unnecessary. We
// should provide an optional interface for drivers to
// implement to say, "don't worry, the []bytes that I return
// from Next will not be modified again." (for instance, if
// they were obtained from the network anyway) But for now we
// don't care.
defer r.rows.Close()
for _, dp := range dest {
if _, ok := dp.(*sql.RawBytes); ok {
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
}
}
if !r.rows.Next() {
if err := r.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := r.rows.Scan(dest...)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
if err := r.rows.Close(); err != nil {
return err
}
return nil
}
// Columns returns the underlying sql.Rows.Columns(), or the deferred error usually
// returned by Row.Scan()
func (r *Row) Columns() ([]string, error) {
if r.err != nil {
return []string{}, r.err
}
return r.rows.Columns()
}
// Err returns the error encountered while scanning.
func (r *Row) Err() error {
return r.err
}
// DB is a wrapper around sql.DB which keeps track of the driverName upon Open,
// used mostly to automatically bind named queries using the right bindvars.
type DB struct {
*sql.DB
driverName string
unsafe bool
Mapper *reflectx.Mapper
}
// NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB. The
// driverName of the original database is required for named query support.
func NewDb(db *sql.DB, driverName string) *DB {
return &DB{DB: db, driverName: driverName, Mapper: mapper()}
}
// DriverName returns the driverName passed to the Open function for this DB.
func (db *DB) DriverName() string {
return db.driverName
}
// Open is the same as sql.Open, but returns an *sqlx.DB instead.
func Open(driverName, dataSourceName string) (*DB, error) {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
return nil, err
}
return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err
}
// MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error.
func MustOpen(driverName, dataSourceName string) *DB {
db, err := Open(driverName, dataSourceName)
if err != nil {
panic(err)
}
return db
}
// MapperFunc sets a new mapper for this db using the default sqlx struct tag
// and the provided mapper function.
func (db *DB) MapperFunc(mf func(string) string) {
db.Mapper = reflectx.NewMapperFunc("db", mf)
}
// Rebind transforms a query from QUESTION to the DB driver's bindvar type.
func (db *DB) Rebind(query string) string {
return Rebind(BindType(db.driverName), query)
}
// Unsafe returns a version of DB which will silently succeed to scan when
// columns in the SQL result have no fields in the destination struct.
// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its
// safety behavior.
func (db *DB) Unsafe() *DB {
return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper}
}
// BindNamed binds a query using the DB driver's bindvar type.
func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper)
}
// NamedQuery using this DB.
func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) {
return NamedQuery(db, query, arg)
}
// NamedExec using this DB.
func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) {
return NamedExec(db, query, arg)
}
// Select using this DB.
func (db *DB) Select(dest interface{}, query string, args ...interface{}) error {
return Select(db, dest, query, args...)
}
// Get using this DB.
func (db *DB) Get(dest interface{}, query string, args ...interface{}) error {
return Get(db, dest, query, args...)
}
// MustBegin starts a transaction, and panics on error. Returns an *sqlx.Tx instead
// of an *sql.Tx.
func (db *DB) MustBegin() *Tx {
tx, err := db.Beginx()
if err != nil {
panic(err)
}
return tx
}
// Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx.
func (db *DB) Beginx() (*Tx, error) {
tx, err := db.DB.Begin()
if err != nil {
return nil, err
}
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
}
// Queryx queries the database and returns an *sqlx.Rows.
func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
r, err := db.DB.Query(query, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
}
// QueryRowx queries the database and returns an *sqlx.Row.
func (db *DB) QueryRowx(query string, args ...interface{}) *Row {
rows, err := db.DB.Query(query, args...)
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
}
// MustExec (panic) runs MustExec using this database.
func (db *DB) MustExec(query string, args ...interface{}) sql.Result {
return MustExec(db, query, args...)
}
// Preparex returns an sqlx.Stmt instead of a sql.Stmt
func (db *DB) Preparex(query string) (*Stmt, error) {
return Preparex(db, query)
}
// PrepareNamed returns an sqlx.NamedStmt
func (db *DB) PrepareNamed(query string) (*NamedStmt, error) {
return prepareNamed(db, query)
}
// Tx is an sqlx wrapper around sql.Tx with extra functionality
type Tx struct {
*sql.Tx
driverName string
unsafe bool
Mapper *reflectx.Mapper
}
// DriverName returns the driverName used by the DB which began this transaction.
func (tx *Tx) DriverName() string {
return tx.driverName
}
// Rebind a query within a transaction's bindvar type.
func (tx *Tx) Rebind(query string) string {
return Rebind(BindType(tx.driverName), query)
}
// Unsafe returns a version of Tx which will silently succeed to scan when
// columns in the SQL result have no fields in the destination struct.
func (tx *Tx) Unsafe() *Tx {
return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper}
}
// BindNamed binds a query within a transaction's bindvar type.
func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper)
}
// NamedQuery within a transaction.
func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) {
return NamedQuery(tx, query, arg)
}
// NamedExec a named query within a transaction.
func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) {
return NamedExec(tx, query, arg)
}
// Select within a transaction.
func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error {
return Select(tx, dest, query, args...)
}
// Queryx within a transaction.
func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
r, err := tx.Tx.Query(query, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
}
// QueryRowx within a transaction.
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
rows, err := tx.Tx.Query(query, args...)
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
}
// Get within a transaction.
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
return Get(tx, dest, query, args...)
}
// MustExec runs MustExec within a transaction.
func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result {
return MustExec(tx, query, args...)
}
// Preparex a statement within a transaction.
func (tx *Tx) Preparex(query string) (*Stmt, error) {
return Preparex(tx, query)
}
// Stmtx returns a version of the prepared statement which runs within a transaction. Provided
// stmt can be either *sql.Stmt or *sqlx.Stmt.
func (tx *Tx) Stmtx(stmt interface{}) *Stmt {
var s *sql.Stmt
switch v := stmt.(type) {
case Stmt:
s = v.Stmt
case *Stmt:
s = v.Stmt
case sql.Stmt:
s = &v
case *sql.Stmt:
s = v
default:
panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type()))
}
return &Stmt{Stmt: tx.Stmt(s), Mapper: tx.Mapper}
}
// NamedStmt returns a version of the prepared statement which runs within a transaction.
func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt {
return &NamedStmt{
QueryString: stmt.QueryString,
Params: stmt.Params,
Stmt: tx.Stmtx(stmt.Stmt),
}
}
// PrepareNamed returns an sqlx.NamedStmt
func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) {
return prepareNamed(tx, query)
}
// Stmt is an sqlx wrapper around sql.Stmt with extra functionality
type Stmt struct {
*sql.Stmt
unsafe bool
Mapper *reflectx.Mapper
}
// Unsafe returns a version of Stmt which will silently succeed to scan when
// columns in the SQL result have no fields in the destination struct.
func (s *Stmt) Unsafe() *Stmt {
return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper}
}
// Select using the prepared statement.
func (s *Stmt) Select(dest interface{}, args ...interface{}) error {
return Select(&qStmt{s}, dest, "", args...)
}
// Get using the prepared statement.
func (s *Stmt) Get(dest interface{}, args ...interface{}) error {
return Get(&qStmt{s}, dest, "", args...)
}
// MustExec (panic) using this statement. Note that the query portion of the error
// output will be blank, as Stmt does not expose its query.
func (s *Stmt) MustExec(args ...interface{}) sql.Result {
return MustExec(&qStmt{s}, "", args...)
}
// QueryRowx using this statement.
func (s *Stmt) QueryRowx(args ...interface{}) *Row {
qs := &qStmt{s}
return qs.QueryRowx("", args...)
}
// Queryx using this statement.
func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) {
qs := &qStmt{s}
return qs.Queryx("", args...)
}
// qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by
// implementing those interfaces and ignoring the `query` argument.
type qStmt struct{ *Stmt }
func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) {
return q.Stmt.Query(args...)
}
func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) {
r, err := q.Stmt.Query(args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
}
func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row {
rows, err := q.Stmt.Query(args...)
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}
}
func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) {
return q.Stmt.Exec(args...)
}
// Rows is a wrapper around sql.Rows which caches costly reflect operations
// during a looped StructScan
type Rows struct {
*sql.Rows
unsafe bool
Mapper *reflectx.Mapper
// these fields cache memory use for a rows during iteration w/ structScan
started bool
fields [][]int
values []interface{}
}
// SliceScan using this Rows.
func (r *Rows) SliceScan() ([]interface{}, error) {
return SliceScan(r)
}
// MapScan using this Rows.
func (r *Rows) MapScan(dest map[string]interface{}) error {
return MapScan(r, dest)
}
// StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct.
// Use this and iterate over Rows manually when the memory load of Select() might be
// prohibitive. *Rows.StructScan caches the reflect work of matching up column
// positions to fields to avoid that overhead per scan, which means it is not safe
// to run StructScan on the same Rows instance with different struct types.
func (r *Rows) StructScan(dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
v = reflect.Indirect(v)
if !r.started {
columns, err := r.Columns()
if err != nil {
return err
}
m := r.Mapper
r.fields = m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(r.fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s", columns[f])
}
r.values = make([]interface{}, len(columns))
r.started = true
}
err := fieldsByTraversal(v, r.fields, r.values, true)
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
err = r.Scan(r.values...)
if err != nil {
return err
}
return r.Err()
}
// Connect to a database and verify with a ping.
func Connect(driverName, dataSourceName string) (*DB, error) {
db, err := Open(driverName, dataSourceName)
if err != nil {
return db, err
}
err = db.Ping()
return db, err
}
// MustConnect connects to a database and panics on error.
func MustConnect(driverName, dataSourceName string) *DB {
db, err := Connect(driverName, dataSourceName)
if err != nil {
panic(err)
}
return db
}
// Preparex prepares a statement.
func Preparex(p Preparer, query string) (*Stmt, error) {
s, err := p.Prepare(query)
if err != nil {
return nil, err
}
return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err
}
// Select executes a query using the provided Queryer, and StructScans each row
// into dest, which must be a slice. If the slice elements are scannable, then
// the result set must have only one column. Otherwise, StructScan is used.
// The *sql.Rows are closed automatically.
func Select(q Queryer, dest interface{}, query string, args ...interface{}) error {
rows, err := q.Queryx(query, args...)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// Get does a QueryRow using the provided Queryer, and scans the resulting row
// to dest. If dest is scannable, the result must only have one column. Otherwise,
// StructScan is used. Get will return sql.ErrNoRows like row.Scan would.
func Get(q Queryer, dest interface{}, query string, args ...interface{}) error {
r := q.QueryRowx(query, args...)
return r.scanAny(dest, false)
}
// LoadFile exec's every statement in a file (as a single call to Exec).
// LoadFile may return a nil *sql.Result if errors are encountered locating or
// reading the file at path. LoadFile reads the entire file into memory, so it
// is not suitable for loading large data dumps, but can be useful for initializing
// schemas or loading indexes.
//
// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3
// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting
// this by requiring something with DriverName() and then attempting to split the
// queries will be difficult to get right, and its current driver-specific behavior
// is deemed at least not complex in its incorrectness.
func LoadFile(e Execer, path string) (*sql.Result, error) {
realpath, err := filepath.Abs(path)
if err != nil {
return nil, err
}
contents, err := ioutil.ReadFile(realpath)
if err != nil {
return nil, err
}
res, err := e.Exec(string(contents))
return &res, err
}
// MustExec execs the query using e and panics if there was an error.
func MustExec(e Execer, query string, args ...interface{}) sql.Result {
res, err := e.Exec(query, args...)
if err != nil {
panic(err)
}
return res
}
// SliceScan using this Rows.
func (r *Row) SliceScan() ([]interface{}, error) {
return SliceScan(r)
}
// MapScan using this Rows.
func (r *Row) MapScan(dest map[string]interface{}) error {
return MapScan(r, dest)
}
func (r *Row) scanAny(dest interface{}, structOnly bool) error {
if r.err != nil {
return r.err
}
defer r.rows.Close()
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
if v.IsNil() {
return errors.New("nil pointer passed to StructScan destination")
}
base := reflectx.Deref(v.Type())
scannable := isScannable(base)
if structOnly && scannable {
return structOnlyError(base)
}
columns, err := r.Columns()
if err != nil {
return err
}
if scannable && len(columns) > 1 {
return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns))
}
if scannable {
return r.Scan(dest)
}
m := r.Mapper
fields := m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s", columns[f])
}
values := make([]interface{}, len(columns))
err = fieldsByTraversal(v, fields, values, true)
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
return r.Scan(values...)
}
// StructScan a single Row into dest.
func (r *Row) StructScan(dest interface{}) error {
return r.scanAny(dest, true)
}
// SliceScan a row, returning a []interface{} with values similar to MapScan.
// This function is primarly intended for use where the number of columns
// is not known. Because you can pass an []interface{} directly to Scan,
// it's recommended that you do that as it will not have to allocate new
// slices per row.
func SliceScan(r ColScanner) ([]interface{}, error) {
// ignore r.started, since we needn't use reflect for anything.
columns, err := r.Columns()
if err != nil {
return []interface{}{}, err
}
values := make([]interface{}, len(columns))
for i := range values {
values[i] = new(interface{})
}
err = r.Scan(values...)
if err != nil {
return values, err
}
for i := range columns {
values[i] = *(values[i].(*interface{}))
}
return values, r.Err()
}
// MapScan scans a single Row into the dest map[string]interface{}.
// Use this to get results for SQL that might not be under your control
// (for instance, if you're building an interface for an SQL server that
// executes SQL from input). Please do not use this as a primary interface!
// This will modify the map sent to it in place, so reuse the same map with
// care. Columns which occur more than once in the result will overwrite
// eachother!
func MapScan(r ColScanner, dest map[string]interface{}) error {
// ignore r.started, since we needn't use reflect for anything.
columns, err := r.Columns()
if err != nil {
return err
}
values := make([]interface{}, len(columns))
for i := range values {
values[i] = new(interface{})
}
err = r.Scan(values...)
if err != nil {
return err
}
for i, column := range columns {
dest[column] = *(values[i].(*interface{}))
}
return r.Err()
}
type rowsi interface {
Close() error
Columns() ([]string, error)
Err() error
Next() bool
Scan(...interface{}) error
}
// structOnlyError returns an error appropriate for type when a non-scannable
// struct is expected but something else is given
func structOnlyError(t reflect.Type) error {
isStruct := t.Kind() == reflect.Struct
isScanner := reflect.PtrTo(t).Implements(_scannerInterface)
if !isStruct {
return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind())
}
if isScanner {
return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name())
}
return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
}
// scanAll scans all rows into a destination, which must be a slice of any
// type. If the destination slice type is a Struct, then StructScan will be
// used on each row. If the destination is some other kind of base type, then
// each row must only have one column which can scan into that type. This
// allows you to do something like:
//
// rows, _ := db.Query("select id from people;")
// var ids []int
// scanAll(rows, &ids, false)
//
// and ids will be a list of the id results. I realize that this is a desirable
// interface to expose to users, but for now it will only be exposed via changes
// to `Get` and `Select`. The reason that this has been implemented like this is
// this is the only way to not duplicate reflect work in the new API while
// maintaining backwards compatibility.
func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
var v, vp reflect.Value
value := reflect.ValueOf(dest)
// json.Unmarshal returns errors for these
if value.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
if value.IsNil() {
return errors.New("nil pointer passed to StructScan destination")
}
direct := reflect.Indirect(value)
slice, err := baseType(value.Type(), reflect.Slice)
if err != nil {
return err
}
isPtr := slice.Elem().Kind() == reflect.Ptr
base := reflectx.Deref(slice.Elem())
scannable := isScannable(base)
if structOnly && scannable {
return structOnlyError(base)
}
columns, err := rows.Columns()
if err != nil {
return err
}
// if it's a base type make sure it only has 1 column; if not return an error
if scannable && len(columns) > 1 {
return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns))
}
if !scannable {
var values []interface{}
var m *reflectx.Mapper
switch rows.(type) {
case *Rows:
m = rows.(*Rows).Mapper
default:
m = mapper()
}
fields := m.TraversalsByName(base, columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(fields); err != nil && !isUnsafe(rows) {
return fmt.Errorf("missing destination name %s", columns[f])
}
values = make([]interface{}, len(columns))
for rows.Next() {
// create a new struct type (which returns PtrTo) and indirect it
vp = reflect.New(base)
v = reflect.Indirect(vp)
err = fieldsByTraversal(v, fields, values, true)
// scan into the struct field pointers and append to our results
err = rows.Scan(values...)
if err != nil {
return err
}
if isPtr {
direct.Set(reflect.Append(direct, vp))
} else {
direct.Set(reflect.Append(direct, v))
}
}
} else {
for rows.Next() {
vp = reflect.New(base)
err = rows.Scan(vp.Interface())
// append
if isPtr {
direct.Set(reflect.Append(direct, vp))
} else {
direct.Set(reflect.Append(direct, reflect.Indirect(vp)))
}
}
}
return rows.Err()
}
// FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately
// it doesn't really feel like it's named properly. There is an incongruency
// between this and the way that StructScan (which might better be ScanStruct
// anyway) works on a rows object.
// StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice.
// StructScan will scan in the entire rows result, so if you need do not want to
// allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan.
// If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default.
func StructScan(rows rowsi, dest interface{}) error {
return scanAll(rows, dest, true)
}
// reflect helpers
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
t = reflectx.Deref(t)
if t.Kind() != expected {
return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind())
}
return t, nil
}
// fieldsByName fills a values interface with fields from the passed value based
// on the traversals in int. If ptrs is true, return addresses instead of values.
// We write this instead of using FieldsByName to save allocations and map lookups
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}
for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else {
values[i] = f.Interface()
}
}
return nil
}
func missingFields(transversals [][]int) (field int, err error) {
for i, t := range transversals {
if len(t) == 0 {
return i, errors.New("missing field")
}
}
return 0, nil
}

View File

@ -1,191 +0,0 @@
All files in this repository are licensed as follows. If you contribute
to this repository, it is assumed that you license your contribution
under the same license unless you state otherwise.
All files Copyright (C) 2015 Canonical Ltd. unless otherwise specified in the file.
This software is licensed under the LGPLv3, included below.
As a special exception to the GNU Lesser General Public License version 3
("LGPL3"), the copyright holders of this Library give you permission to
convey to a third party a Combined Work that links statically or dynamically
to this Library without providing any Minimal Corresponding Source or
Minimal Application Code as set out in 4d or providing the installation
information set out in section 4e, provided that you comply with the other
provisions of LGPL3 and provided that you meet, for the Application the
terms and conditions of the license(s) which apply to the Application.
Except as stated in this special exception, the provisions of LGPL3 will
continue to comply in full to this Library. If you modify this Library, you
may apply this exception to your version of this Library, but you are not
obliged to do so. If you do not wish to do so, delete this exception
statement from your version. This exception does not (and cannot) modify any
license terms which apply to the Application, with which you must still
comply.
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.

View File

@ -1,81 +0,0 @@
// Copyright 2013, 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
/*
[godoc-link-here]
The juju/errors provides an easy way to annotate errors without losing the
orginal error context.
The exported `New` and `Errorf` functions are designed to replace the
`errors.New` and `fmt.Errorf` functions respectively. The same underlying
error is there, but the package also records the location at which the error
was created.
A primary use case for this library is to add extra context any time an
error is returned from a function.
if err := SomeFunc(); err != nil {
return err
}
This instead becomes:
if err := SomeFunc(); err != nil {
return errors.Trace(err)
}
which just records the file and line number of the Trace call, or
if err := SomeFunc(); err != nil {
return errors.Annotate(err, "more context")
}
which also adds an annotation to the error.
When you want to check to see if an error is of a particular type, a helper
function is normally exported by the package that returned the error, like the
`os` package does. The underlying cause of the error is available using the
`Cause` function.
os.IsNotExist(errors.Cause(err))
The result of the `Error()` call on an annotated error is the annotations joined
with colons, then the result of the `Error()` method for the underlying error
that was the cause.
err := errors.Errorf("original")
err = errors.Annotatef(err, "context")
err = errors.Annotatef(err, "more context")
err.Error() -> "more context: context: original"
Obviously recording the file, line and functions is not very useful if you
cannot get them back out again.
errors.ErrorStack(err)
will return something like:
first error
github.com/juju/errors/annotation_test.go:193:
github.com/juju/errors/annotation_test.go:194: annotation
github.com/juju/errors/annotation_test.go:195:
github.com/juju/errors/annotation_test.go:196: more context
github.com/juju/errors/annotation_test.go:197:
The first error was generated by an external system, so there was no location
associated. The second, fourth, and last lines were generated with Trace calls,
and the other two through Annotate.
Sometimes when responding to an error you want to return a more specific error
for the situation.
if err := FindField(field); err != nil {
return errors.Wrap(err, errors.NotFoundf(field))
}
This returns an error where the complete error stack is still available, and
`errors.Cause()` will return the `NotFound` error.
*/
package errors

View File

@ -1,145 +0,0 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"fmt"
"reflect"
"runtime"
)
// Err holds a description of an error along with information about
// where the error was created.
//
// It may be embedded in custom error types to add extra information that
// this errors package can understand.
type Err struct {
// message holds an annotation of the error.
message string
// cause holds the cause of the error as returned
// by the Cause method.
cause error
// previous holds the previous error in the error stack, if any.
previous error
// file and line hold the source code location where the error was
// created.
file string
line int
}
// NewErr is used to return an Err for the purpose of embedding in other
// structures. The location is not specified, and needs to be set with a call
// to SetLocation.
//
// For example:
// type FooError struct {
// errors.Err
// code int
// }
//
// func NewFooError(code int) error {
// err := &FooError{errors.NewErr("foo"), code}
// err.SetLocation(1)
// return err
// }
func NewErr(format string, args ...interface{}) Err {
return Err{
message: fmt.Sprintf(format, args...),
}
}
// NewErrWithCause is used to return an Err with case by other error for the purpose of embedding in other
// structures. The location is not specified, and needs to be set with a call
// to SetLocation.
//
// For example:
// type FooError struct {
// errors.Err
// code int
// }
//
// func (e *FooError) Annotate(format string, args ...interface{}) error {
// err := &FooError{errors.NewErrWithCause(e.Err, format, args...), e.code}
// err.SetLocation(1)
// return err
// })
func NewErrWithCause(other error, format string, args ...interface{}) Err {
return Err{
message: fmt.Sprintf(format, args...),
cause: Cause(other),
previous: other,
}
}
// Location is the file and line of where the error was most recently
// created or annotated.
func (e *Err) Location() (filename string, line int) {
return e.file, e.line
}
// Underlying returns the previous error in the error stack, if any. A client
// should not ever really call this method. It is used to build the error
// stack and should not be introspected by client calls. Or more
// specifically, clients should not depend on anything but the `Cause` of an
// error.
func (e *Err) Underlying() error {
return e.previous
}
// The Cause of an error is the most recent error in the error stack that
// meets one of these criteria: the original error that was raised; the new
// error that was passed into the Wrap function; the most recently masked
// error; or nil if the error itself is considered the Cause. Normally this
// method is not invoked directly, but instead through the Cause stand alone
// function.
func (e *Err) Cause() error {
return e.cause
}
// Message returns the message stored with the most recent location. This is
// the empty string if the most recent call was Trace, or the message stored
// with Annotate or Mask.
func (e *Err) Message() string {
return e.message
}
// Error implements error.Error.
func (e *Err) Error() string {
// We want to walk up the stack of errors showing the annotations
// as long as the cause is the same.
err := e.previous
if !sameError(Cause(err), e.cause) && e.cause != nil {
err = e.cause
}
switch {
case err == nil:
return e.message
case e.message == "":
return err.Error()
}
return fmt.Sprintf("%s: %v", e.message, err)
}
// SetLocation records the source location of the error at callDepth stack
// frames above the call.
func (e *Err) SetLocation(callDepth int) {
_, file, line, _ := runtime.Caller(callDepth + 1)
e.file = trimGoPath(file)
e.line = line
}
// StackTrace returns one string for each location recorded in the stack of
// errors. The first value is the originating error, with a line for each
// other annotation or tracing of the error.
func (e *Err) StackTrace() []string {
return errorStack(e)
}
// Ideally we'd have a way to check identity, but deep equals will do.
func sameError(e1, e2 error) bool {
return reflect.DeepEqual(e1, e2)
}

View File

@ -1,284 +0,0 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"fmt"
)
// wrap is a helper to construct an *wrapper.
func wrap(err error, format, suffix string, args ...interface{}) Err {
newErr := Err{
message: fmt.Sprintf(format+suffix, args...),
previous: err,
}
newErr.SetLocation(2)
return newErr
}
// notFound represents an error when something has not been found.
type notFound struct {
Err
}
// NotFoundf returns an error which satisfies IsNotFound().
func NotFoundf(format string, args ...interface{}) error {
return &notFound{wrap(nil, format, " not found", args...)}
}
// NewNotFound returns an error which wraps err that satisfies
// IsNotFound().
func NewNotFound(err error, msg string) error {
return &notFound{wrap(err, msg, "")}
}
// IsNotFound reports whether err was created with NotFoundf() or
// NewNotFound().
func IsNotFound(err error) bool {
err = Cause(err)
_, ok := err.(*notFound)
return ok
}
// userNotFound represents an error when an inexistent user is looked up.
type userNotFound struct {
Err
}
// UserNotFoundf returns an error which satisfies IsUserNotFound().
func UserNotFoundf(format string, args ...interface{}) error {
return &userNotFound{wrap(nil, format, " user not found", args...)}
}
// NewUserNotFound returns an error which wraps err and satisfies
// IsUserNotFound().
func NewUserNotFound(err error, msg string) error {
return &userNotFound{wrap(err, msg, "")}
}
// IsUserNotFound reports whether err was created with UserNotFoundf() or
// NewUserNotFound().
func IsUserNotFound(err error) bool {
err = Cause(err)
_, ok := err.(*userNotFound)
return ok
}
// unauthorized represents an error when an operation is unauthorized.
type unauthorized struct {
Err
}
// Unauthorizedf returns an error which satisfies IsUnauthorized().
func Unauthorizedf(format string, args ...interface{}) error {
return &unauthorized{wrap(nil, format, "", args...)}
}
// NewUnauthorized returns an error which wraps err and satisfies
// IsUnauthorized().
func NewUnauthorized(err error, msg string) error {
return &unauthorized{wrap(err, msg, "")}
}
// IsUnauthorized reports whether err was created with Unauthorizedf() or
// NewUnauthorized().
func IsUnauthorized(err error) bool {
err = Cause(err)
_, ok := err.(*unauthorized)
return ok
}
// notImplemented represents an error when something is not
// implemented.
type notImplemented struct {
Err
}
// NotImplementedf returns an error which satisfies IsNotImplemented().
func NotImplementedf(format string, args ...interface{}) error {
return &notImplemented{wrap(nil, format, " not implemented", args...)}
}
// NewNotImplemented returns an error which wraps err and satisfies
// IsNotImplemented().
func NewNotImplemented(err error, msg string) error {
return &notImplemented{wrap(err, msg, "")}
}
// IsNotImplemented reports whether err was created with
// NotImplementedf() or NewNotImplemented().
func IsNotImplemented(err error) bool {
err = Cause(err)
_, ok := err.(*notImplemented)
return ok
}
// alreadyExists represents and error when something already exists.
type alreadyExists struct {
Err
}
// AlreadyExistsf returns an error which satisfies IsAlreadyExists().
func AlreadyExistsf(format string, args ...interface{}) error {
return &alreadyExists{wrap(nil, format, " already exists", args...)}
}
// NewAlreadyExists returns an error which wraps err and satisfies
// IsAlreadyExists().
func NewAlreadyExists(err error, msg string) error {
return &alreadyExists{wrap(err, msg, "")}
}
// IsAlreadyExists reports whether the error was created with
// AlreadyExistsf() or NewAlreadyExists().
func IsAlreadyExists(err error) bool {
err = Cause(err)
_, ok := err.(*alreadyExists)
return ok
}
// notSupported represents an error when something is not supported.
type notSupported struct {
Err
}
// NotSupportedf returns an error which satisfies IsNotSupported().
func NotSupportedf(format string, args ...interface{}) error {
return &notSupported{wrap(nil, format, " not supported", args...)}
}
// NewNotSupported returns an error which wraps err and satisfies
// IsNotSupported().
func NewNotSupported(err error, msg string) error {
return &notSupported{wrap(err, msg, "")}
}
// IsNotSupported reports whether the error was created with
// NotSupportedf() or NewNotSupported().
func IsNotSupported(err error) bool {
err = Cause(err)
_, ok := err.(*notSupported)
return ok
}
// notValid represents an error when something is not valid.
type notValid struct {
Err
}
// NotValidf returns an error which satisfies IsNotValid().
func NotValidf(format string, args ...interface{}) error {
return &notValid{wrap(nil, format, " not valid", args...)}
}
// NewNotValid returns an error which wraps err and satisfies IsNotValid().
func NewNotValid(err error, msg string) error {
return &notValid{wrap(err, msg, "")}
}
// IsNotValid reports whether the error was created with NotValidf() or
// NewNotValid().
func IsNotValid(err error) bool {
err = Cause(err)
_, ok := err.(*notValid)
return ok
}
// notProvisioned represents an error when something is not yet provisioned.
type notProvisioned struct {
Err
}
// NotProvisionedf returns an error which satisfies IsNotProvisioned().
func NotProvisionedf(format string, args ...interface{}) error {
return &notProvisioned{wrap(nil, format, " not provisioned", args...)}
}
// NewNotProvisioned returns an error which wraps err that satisfies
// IsNotProvisioned().
func NewNotProvisioned(err error, msg string) error {
return &notProvisioned{wrap(err, msg, "")}
}
// IsNotProvisioned reports whether err was created with NotProvisionedf() or
// NewNotProvisioned().
func IsNotProvisioned(err error) bool {
err = Cause(err)
_, ok := err.(*notProvisioned)
return ok
}
// notAssigned represents an error when something is not yet assigned to
// something else.
type notAssigned struct {
Err
}
// NotAssignedf returns an error which satisfies IsNotAssigned().
func NotAssignedf(format string, args ...interface{}) error {
return &notAssigned{wrap(nil, format, " not assigned", args...)}
}
// NewNotAssigned returns an error which wraps err that satisfies
// IsNotAssigned().
func NewNotAssigned(err error, msg string) error {
return &notAssigned{wrap(err, msg, "")}
}
// IsNotAssigned reports whether err was created with NotAssignedf() or
// NewNotAssigned().
func IsNotAssigned(err error) bool {
err = Cause(err)
_, ok := err.(*notAssigned)
return ok
}
// badRequest represents an error when a request has bad parameters.
type badRequest struct {
Err
}
// BadRequestf returns an error which satisfies IsBadRequest().
func BadRequestf(format string, args ...interface{}) error {
return &badRequest{wrap(nil, format, "", args...)}
}
// NewBadRequest returns an error which wraps err that satisfies
// IsBadRequest().
func NewBadRequest(err error, msg string) error {
return &badRequest{wrap(err, msg, "")}
}
// IsBadRequest reports whether err was created with BadRequestf() or
// NewBadRequest().
func IsBadRequest(err error) bool {
err = Cause(err)
_, ok := err.(*badRequest)
return ok
}
// methodNotAllowed represents an error when an HTTP request
// is made with an inappropriate method.
type methodNotAllowed struct {
Err
}
// MethodNotAllowedf returns an error which satisfies IsMethodNotAllowed().
func MethodNotAllowedf(format string, args ...interface{}) error {
return &methodNotAllowed{wrap(nil, format, "", args...)}
}
// NewMethodNotAllowed returns an error which wraps err that satisfies
// IsMethodNotAllowed().
func NewMethodNotAllowed(err error, msg string) error {
return &methodNotAllowed{wrap(err, msg, "")}
}
// IsMethodNotAllowed reports whether err was created with MethodNotAllowedf() or
// NewMethodNotAllowed().
func IsMethodNotAllowed(err error) bool {
err = Cause(err)
_, ok := err.(*methodNotAllowed)
return ok
}

View File

@ -1,330 +0,0 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"fmt"
"strings"
)
// New is a drop in replacement for the standard libary errors module that records
// the location that the error is created.
//
// For example:
// return errors.New("validation failed")
//
func New(message string) error {
err := &Err{message: message}
err.SetLocation(1)
return err
}
// Errorf creates a new annotated error and records the location that the
// error is created. This should be a drop in replacement for fmt.Errorf.
//
// For example:
// return errors.Errorf("validation failed: %s", message)
//
func Errorf(format string, args ...interface{}) error {
err := &Err{message: fmt.Sprintf(format, args...)}
err.SetLocation(1)
return err
}
// Trace adds the location of the Trace call to the stack. The Cause of the
// resulting error is the same as the error parameter. If the other error is
// nil, the result will be nil.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Trace(err)
// }
//
func Trace(other error) error {
if other == nil {
return nil
}
err := &Err{previous: other, cause: Cause(other)}
err.SetLocation(1)
return err
}
// Annotate is used to add extra context to an existing error. The location of
// the Annotate call is recorded with the annotations. The file, line and
// function are also recorded.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Annotate(err, "failed to frombulate")
// }
//
func Annotate(other error, message string) error {
if other == nil {
return nil
}
err := &Err{
previous: other,
cause: Cause(other),
message: message,
}
err.SetLocation(1)
return err
}
// Annotatef is used to add extra context to an existing error. The location of
// the Annotate call is recorded with the annotations. The file, line and
// function are also recorded.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Annotatef(err, "failed to frombulate the %s", arg)
// }
//
func Annotatef(other error, format string, args ...interface{}) error {
if other == nil {
return nil
}
err := &Err{
previous: other,
cause: Cause(other),
message: fmt.Sprintf(format, args...),
}
err.SetLocation(1)
return err
}
// DeferredAnnotatef annotates the given error (when it is not nil) with the given
// format string and arguments (like fmt.Sprintf). If *err is nil, DeferredAnnotatef
// does nothing. This method is used in a defer statement in order to annotate any
// resulting error with the same message.
//
// For example:
//
// defer DeferredAnnotatef(&err, "failed to frombulate the %s", arg)
//
func DeferredAnnotatef(err *error, format string, args ...interface{}) {
if *err == nil {
return
}
newErr := &Err{
message: fmt.Sprintf(format, args...),
cause: Cause(*err),
previous: *err,
}
newErr.SetLocation(1)
*err = newErr
}
// Wrap changes the Cause of the error. The location of the Wrap call is also
// stored in the error stack.
//
// For example:
// if err := SomeFunc(); err != nil {
// newErr := &packageError{"more context", private_value}
// return errors.Wrap(err, newErr)
// }
//
func Wrap(other, newDescriptive error) error {
err := &Err{
previous: other,
cause: newDescriptive,
}
err.SetLocation(1)
return err
}
// Wrapf changes the Cause of the error, and adds an annotation. The location
// of the Wrap call is also stored in the error stack.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Wrapf(err, simpleErrorType, "invalid value %q", value)
// }
//
func Wrapf(other, newDescriptive error, format string, args ...interface{}) error {
err := &Err{
message: fmt.Sprintf(format, args...),
previous: other,
cause: newDescriptive,
}
err.SetLocation(1)
return err
}
// Mask masks the given error with the given format string and arguments (like
// fmt.Sprintf), returning a new error that maintains the error stack, but
// hides the underlying error type. The error string still contains the full
// annotations. If you want to hide the annotations, call Wrap.
func Maskf(other error, format string, args ...interface{}) error {
if other == nil {
return nil
}
err := &Err{
message: fmt.Sprintf(format, args...),
previous: other,
}
err.SetLocation(1)
return err
}
// Mask hides the underlying error type, and records the location of the masking.
func Mask(other error) error {
if other == nil {
return nil
}
err := &Err{
previous: other,
}
err.SetLocation(1)
return err
}
// Cause returns the cause of the given error. This will be either the
// original error, or the result of a Wrap or Mask call.
//
// Cause is the usual way to diagnose errors that may have been wrapped by
// the other errors functions.
func Cause(err error) error {
var diag error
if err, ok := err.(causer); ok {
diag = err.Cause()
}
if diag != nil {
return diag
}
return err
}
type causer interface {
Cause() error
}
type wrapper interface {
// Message returns the top level error message,
// not including the message from the Previous
// error.
Message() string
// Underlying returns the Previous error, or nil
// if there is none.
Underlying() error
}
type locationer interface {
Location() (string, int)
}
var (
_ wrapper = (*Err)(nil)
_ locationer = (*Err)(nil)
_ causer = (*Err)(nil)
)
// Details returns information about the stack of errors wrapped by err, in
// the format:
//
// [{filename:99: error one} {otherfile:55: cause of error one}]
//
// This is a terse alternative to ErrorStack as it returns a single line.
func Details(err error) string {
if err == nil {
return "[]"
}
var s []byte
s = append(s, '[')
for {
s = append(s, '{')
if err, ok := err.(locationer); ok {
file, line := err.Location()
if file != "" {
s = append(s, fmt.Sprintf("%s:%d", file, line)...)
s = append(s, ": "...)
}
}
if cerr, ok := err.(wrapper); ok {
s = append(s, cerr.Message()...)
err = cerr.Underlying()
} else {
s = append(s, err.Error()...)
err = nil
}
s = append(s, '}')
if err == nil {
break
}
s = append(s, ' ')
}
s = append(s, ']')
return string(s)
}
// ErrorStack returns a string representation of the annotated error. If the
// error passed as the parameter is not an annotated error, the result is
// simply the result of the Error() method on that error.
//
// If the error is an annotated error, a multi-line string is returned where
// each line represents one entry in the annotation stack. The full filename
// from the call stack is used in the output.
//
// first error
// github.com/juju/errors/annotation_test.go:193:
// github.com/juju/errors/annotation_test.go:194: annotation
// github.com/juju/errors/annotation_test.go:195:
// github.com/juju/errors/annotation_test.go:196: more context
// github.com/juju/errors/annotation_test.go:197:
func ErrorStack(err error) string {
return strings.Join(errorStack(err), "\n")
}
func errorStack(err error) []string {
if err == nil {
return nil
}
// We want the first error first
var lines []string
for {
var buff []byte
if err, ok := err.(locationer); ok {
file, line := err.Location()
// Strip off the leading GOPATH/src path elements.
file = trimGoPath(file)
if file != "" {
buff = append(buff, fmt.Sprintf("%s:%d", file, line)...)
buff = append(buff, ": "...)
}
}
if cerr, ok := err.(wrapper); ok {
message := cerr.Message()
buff = append(buff, message...)
// If there is a cause for this error, and it is different to the cause
// of the underlying error, then output the error string in the stack trace.
var cause error
if err1, ok := err.(causer); ok {
cause = err1.Cause()
}
err = cerr.Underlying()
if cause != nil && !sameError(Cause(err), cause) {
if message != "" {
buff = append(buff, ": "...)
}
buff = append(buff, cause.Error()...)
}
} else {
buff = append(buff, err.Error()...)
err = nil
}
lines = append(lines, string(buff))
if err == nil {
break
}
}
// reverse the lines to get the original error, which was at the end of
// the list, back to the start.
var result []string
for i := len(lines); i > 0; i-- {
result = append(result, lines[i-1])
}
return result
}

View File

@ -1,38 +0,0 @@
// Copyright 2013, 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"runtime"
"strings"
)
// prefixSize is used internally to trim the user specific path from the
// front of the returned filenames from the runtime call stack.
var prefixSize int
// goPath is the deduced path based on the location of this file as compiled.
var goPath string
func init() {
_, file, _, ok := runtime.Caller(0)
if file == "?" {
return
}
if ok {
// We know that the end of the file should be:
// github.com/juju/errors/path.go
size := len(file)
suffix := len("github.com/juju/errors/path.go")
goPath = file[:size-suffix]
prefixSize = len(goPath)
}
}
func trimGoPath(filename string) string {
if strings.HasPrefix(filename, goPath) {
return filename[prefixSize:]
}
return filename
}

View File

@ -1,165 +0,0 @@
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.

View File

@ -1,18 +0,0 @@
// +build freebsd openbsd netbsd dragonfly darwin linux
package log
import (
"log"
"os"
"syscall"
)
func CrashLog(file string) {
f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
log.Println(err.Error())
} else {
syscall.Dup2(int(f.Fd()), 2)
}
}

View File

@ -1,37 +0,0 @@
// +build windows
package log
import (
"log"
"os"
"syscall"
)
var (
kernel32 = syscall.MustLoadDLL("kernel32.dll")
procSetStdHandle = kernel32.MustFindProc("SetStdHandle")
)
func setStdHandle(stdhandle int32, handle syscall.Handle) error {
r0, _, e1 := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdhandle), uintptr(handle), 0)
if r0 == 0 {
if e1 != 0 {
return error(e1)
}
return syscall.EINVAL
}
return nil
}
func CrashLog(file string) {
f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
log.Println(err.Error())
} else {
err = setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(f.Fd()))
if err != nil {
log.Println(err.Error())
}
}
}

View File

@ -1,380 +0,0 @@
//high level log wrapper, so it can output different log based on level
package log
import (
"fmt"
"io"
"log"
"os"
"runtime"
"sync"
"time"
)
const (
Ldate = log.Ldate
Llongfile = log.Llongfile
Lmicroseconds = log.Lmicroseconds
Lshortfile = log.Lshortfile
LstdFlags = log.LstdFlags
Ltime = log.Ltime
)
type (
LogLevel int
LogType int
)
const (
LOG_FATAL = LogType(0x1)
LOG_ERROR = LogType(0x2)
LOG_WARNING = LogType(0x4)
LOG_INFO = LogType(0x8)
LOG_DEBUG = LogType(0x10)
)
const (
LOG_LEVEL_NONE = LogLevel(0x0)
LOG_LEVEL_FATAL = LOG_LEVEL_NONE | LogLevel(LOG_FATAL)
LOG_LEVEL_ERROR = LOG_LEVEL_FATAL | LogLevel(LOG_ERROR)
LOG_LEVEL_WARN = LOG_LEVEL_ERROR | LogLevel(LOG_WARNING)
LOG_LEVEL_INFO = LOG_LEVEL_WARN | LogLevel(LOG_INFO)
LOG_LEVEL_DEBUG = LOG_LEVEL_INFO | LogLevel(LOG_DEBUG)
LOG_LEVEL_ALL = LOG_LEVEL_DEBUG
)
const FORMAT_TIME_DAY string = "20060102"
const FORMAT_TIME_HOUR string = "2006010215"
var _log *logger = New()
func init() {
SetFlags(Ldate | Ltime | Lshortfile)
SetHighlighting(runtime.GOOS != "windows")
}
func Logger() *log.Logger {
return _log._log
}
func SetLevel(level LogLevel) {
_log.SetLevel(level)
}
func GetLogLevel() LogLevel {
return _log.level
}
func SetOutput(out io.Writer) {
_log.SetOutput(out)
}
func SetOutputByName(path string) error {
return _log.SetOutputByName(path)
}
func SetFlags(flags int) {
_log._log.SetFlags(flags)
}
func Info(v ...interface{}) {
_log.Info(v...)
}
func Infof(format string, v ...interface{}) {
_log.Infof(format, v...)
}
func Debug(v ...interface{}) {
_log.Debug(v...)
}
func Debugf(format string, v ...interface{}) {
_log.Debugf(format, v...)
}
func Warn(v ...interface{}) {
_log.Warning(v...)
}
func Warnf(format string, v ...interface{}) {
_log.Warningf(format, v...)
}
func Warning(v ...interface{}) {
_log.Warning(v...)
}
func Warningf(format string, v ...interface{}) {
_log.Warningf(format, v...)
}
func Error(v ...interface{}) {
_log.Error(v...)
}
func Errorf(format string, v ...interface{}) {
_log.Errorf(format, v...)
}
func Fatal(v ...interface{}) {
_log.Fatal(v...)
}
func Fatalf(format string, v ...interface{}) {
_log.Fatalf(format, v...)
}
func SetLevelByString(level string) {
_log.SetLevelByString(level)
}
func SetHighlighting(highlighting bool) {
_log.SetHighlighting(highlighting)
}
func SetRotateByDay() {
_log.SetRotateByDay()
}
func SetRotateByHour() {
_log.SetRotateByHour()
}
type logger struct {
_log *log.Logger
level LogLevel
highlighting bool
dailyRolling bool
hourRolling bool
fileName string
logSuffix string
fd *os.File
lock sync.Mutex
}
func (l *logger) SetHighlighting(highlighting bool) {
l.highlighting = highlighting
}
func (l *logger) SetLevel(level LogLevel) {
l.level = level
}
func (l *logger) SetLevelByString(level string) {
l.level = StringToLogLevel(level)
}
func (l *logger) SetRotateByDay() {
l.dailyRolling = true
l.logSuffix = genDayTime(time.Now())
}
func (l *logger) SetRotateByHour() {
l.hourRolling = true
l.logSuffix = genHourTime(time.Now())
}
func (l *logger) rotate() error {
l.lock.Lock()
defer l.lock.Unlock()
var suffix string
if l.dailyRolling {
suffix = genDayTime(time.Now())
} else if l.hourRolling {
suffix = genHourTime(time.Now())
} else {
return nil
}
// Notice: if suffix is not equal to l.LogSuffix, then rotate
if suffix != l.logSuffix {
err := l.doRotate(suffix)
if err != nil {
return err
}
}
return nil
}
func (l *logger) doRotate(suffix string) error {
// Notice: Not check error, is this ok?
l.fd.Close()
lastFileName := l.fileName + "." + l.logSuffix
err := os.Rename(l.fileName, lastFileName)
if err != nil {
return err
}
err = l.SetOutputByName(l.fileName)
if err != nil {
return err
}
l.logSuffix = suffix
return nil
}
func (l *logger) SetOutput(out io.Writer) {
l._log = log.New(out, l._log.Prefix(), l._log.Flags())
}
func (l *logger) SetOutputByName(path string) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666)
if err != nil {
log.Fatal(err)
}
l.SetOutput(f)
l.fileName = path
l.fd = f
return err
}
func (l *logger) log(t LogType, v ...interface{}) {
if l.level|LogLevel(t) != l.level {
return
}
err := l.rotate()
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
return
}
v1 := make([]interface{}, len(v)+2)
logStr, logColor := LogTypeToString(t)
if l.highlighting {
v1[0] = "\033" + logColor + "m[" + logStr + "]"
copy(v1[1:], v)
v1[len(v)+1] = "\033[0m"
} else {
v1[0] = "[" + logStr + "]"
copy(v1[1:], v)
v1[len(v)+1] = ""
}
s := fmt.Sprintln(v1...)
l._log.Output(4, s)
}
func (l *logger) logf(t LogType, format string, v ...interface{}) {
if l.level|LogLevel(t) != l.level {
return
}
err := l.rotate()
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
return
}
logStr, logColor := LogTypeToString(t)
var s string
if l.highlighting {
s = "\033" + logColor + "m[" + logStr + "] " + fmt.Sprintf(format, v...) + "\033[0m"
} else {
s = "[" + logStr + "] " + fmt.Sprintf(format, v...)
}
l._log.Output(4, s)
}
func (l *logger) Fatal(v ...interface{}) {
l.log(LOG_FATAL, v...)
os.Exit(-1)
}
func (l *logger) Fatalf(format string, v ...interface{}) {
l.logf(LOG_FATAL, format, v...)
os.Exit(-1)
}
func (l *logger) Error(v ...interface{}) {
l.log(LOG_ERROR, v...)
}
func (l *logger) Errorf(format string, v ...interface{}) {
l.logf(LOG_ERROR, format, v...)
}
func (l *logger) Warning(v ...interface{}) {
l.log(LOG_WARNING, v...)
}
func (l *logger) Warningf(format string, v ...interface{}) {
l.logf(LOG_WARNING, format, v...)
}
func (l *logger) Debug(v ...interface{}) {
l.log(LOG_DEBUG, v...)
}
func (l *logger) Debugf(format string, v ...interface{}) {
l.logf(LOG_DEBUG, format, v...)
}
func (l *logger) Info(v ...interface{}) {
l.log(LOG_INFO, v...)
}
func (l *logger) Infof(format string, v ...interface{}) {
l.logf(LOG_INFO, format, v...)
}
func StringToLogLevel(level string) LogLevel {
switch level {
case "fatal":
return LOG_LEVEL_FATAL
case "error":
return LOG_LEVEL_ERROR
case "warn":
return LOG_LEVEL_WARN
case "warning":
return LOG_LEVEL_WARN
case "debug":
return LOG_LEVEL_DEBUG
case "info":
return LOG_LEVEL_INFO
}
return LOG_LEVEL_ALL
}
func LogTypeToString(t LogType) (string, string) {
switch t {
case LOG_FATAL:
return "fatal", "[0;31"
case LOG_ERROR:
return "error", "[0;31"
case LOG_WARNING:
return "warning", "[0;33"
case LOG_DEBUG:
return "debug", "[0;36"
case LOG_INFO:
return "info", "[0;37"
}
return "unknown", "[0;37"
}
func genDayTime(t time.Time) string {
return t.Format(FORMAT_TIME_DAY)
}
func genHourTime(t time.Time) string {
return t.Format(FORMAT_TIME_HOUR)
}
func New() *logger {
return Newlogger(os.Stderr, "")
}
func Newlogger(w io.Writer, prefix string) *logger {
return &logger{_log: log.New(w, prefix, LstdFlags), level: LOG_LEVEL_ALL, highlighting: true}
}

View File

@ -1,187 +0,0 @@
// Copyright (c) 2012 The Go Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package check
import (
"fmt"
"runtime"
"time"
)
var memStats runtime.MemStats
// testingB is a type passed to Benchmark functions to manage benchmark
// timing and to specify the number of iterations to run.
type timer struct {
start time.Time // Time test or benchmark started
duration time.Duration
N int
bytes int64
timerOn bool
benchTime time.Duration
// The initial states of memStats.Mallocs and memStats.TotalAlloc.
startAllocs uint64
startBytes uint64
// The net total of this test after being run.
netAllocs uint64
netBytes uint64
}
// StartTimer starts timing a test. This function is called automatically
// before a benchmark starts, but it can also used to resume timing after
// a call to StopTimer.
func (c *C) StartTimer() {
if !c.timerOn {
c.start = time.Now()
c.timerOn = true
runtime.ReadMemStats(&memStats)
c.startAllocs = memStats.Mallocs
c.startBytes = memStats.TotalAlloc
}
}
// StopTimer stops timing a test. This can be used to pause the timer
// while performing complex initialization that you don't
// want to measure.
func (c *C) StopTimer() {
if c.timerOn {
c.duration += time.Now().Sub(c.start)
c.timerOn = false
runtime.ReadMemStats(&memStats)
c.netAllocs += memStats.Mallocs - c.startAllocs
c.netBytes += memStats.TotalAlloc - c.startBytes
}
}
// ResetTimer sets the elapsed benchmark time to zero.
// It does not affect whether the timer is running.
func (c *C) ResetTimer() {
if c.timerOn {
c.start = time.Now()
runtime.ReadMemStats(&memStats)
c.startAllocs = memStats.Mallocs
c.startBytes = memStats.TotalAlloc
}
c.duration = 0
c.netAllocs = 0
c.netBytes = 0
}
// SetBytes informs the number of bytes that the benchmark processes
// on each iteration. If this is called in a benchmark it will also
// report MB/s.
func (c *C) SetBytes(n int64) {
c.bytes = n
}
func (c *C) nsPerOp() int64 {
if c.N <= 0 {
return 0
}
return c.duration.Nanoseconds() / int64(c.N)
}
func (c *C) mbPerSec() float64 {
if c.bytes <= 0 || c.duration <= 0 || c.N <= 0 {
return 0
}
return (float64(c.bytes) * float64(c.N) / 1e6) / c.duration.Seconds()
}
func (c *C) timerString() string {
if c.N <= 0 {
return fmt.Sprintf("%3.3fs", float64(c.duration.Nanoseconds())/1e9)
}
mbs := c.mbPerSec()
mb := ""
if mbs != 0 {
mb = fmt.Sprintf("\t%7.2f MB/s", mbs)
}
nsop := c.nsPerOp()
ns := fmt.Sprintf("%10d ns/op", nsop)
if c.N > 0 && nsop < 100 {
// The format specifiers here make sure that
// the ones digits line up for all three possible formats.
if nsop < 10 {
ns = fmt.Sprintf("%13.2f ns/op", float64(c.duration.Nanoseconds())/float64(c.N))
} else {
ns = fmt.Sprintf("%12.1f ns/op", float64(c.duration.Nanoseconds())/float64(c.N))
}
}
memStats := ""
if c.benchMem {
allocedBytes := fmt.Sprintf("%8d B/op", int64(c.netBytes)/int64(c.N))
allocs := fmt.Sprintf("%8d allocs/op", int64(c.netAllocs)/int64(c.N))
memStats = fmt.Sprintf("\t%s\t%s", allocedBytes, allocs)
}
return fmt.Sprintf("%8d\t%s%s%s", c.N, ns, mb, memStats)
}
func min(x, y int) int {
if x > y {
return y
}
return x
}
func max(x, y int) int {
if x < y {
return y
}
return x
}
// roundDown10 rounds a number down to the nearest power of 10.
func roundDown10(n int) int {
var tens = 0
// tens = floor(log_10(n))
for n > 10 {
n = n / 10
tens++
}
// result = 10^tens
result := 1
for i := 0; i < tens; i++ {
result *= 10
}
return result
}
// roundUp rounds x up to a number of the form [1eX, 2eX, 5eX].
func roundUp(n int) int {
base := roundDown10(n)
if n < (2 * base) {
return 2 * base
}
if n < (5 * base) {
return 5 * base
}
return 10 * base
}

View File

@ -1,954 +0,0 @@
// Package check is a rich testing extension for Go's testing package.
//
// For details about the project, see:
//
// http://labix.org/gocheck
//
package check
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// -----------------------------------------------------------------------
// Internal type which deals with suite method calling.
const (
fixtureKd = iota
testKd
)
type funcKind int
const (
succeededSt = iota
failedSt
skippedSt
panickedSt
fixturePanickedSt
missedSt
)
type funcStatus uint32
// A method value can't reach its own Method structure.
type methodType struct {
reflect.Value
Info reflect.Method
}
func newMethod(receiver reflect.Value, i int) *methodType {
return &methodType{receiver.Method(i), receiver.Type().Method(i)}
}
func (method *methodType) PC() uintptr {
return method.Info.Func.Pointer()
}
func (method *methodType) suiteName() string {
t := method.Info.Type.In(0)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t.Name()
}
func (method *methodType) String() string {
return method.suiteName() + "." + method.Info.Name
}
func (method *methodType) matches(re *regexp.Regexp) bool {
return (re.MatchString(method.Info.Name) ||
re.MatchString(method.suiteName()) ||
re.MatchString(method.String()))
}
type C struct {
method *methodType
kind funcKind
testName string
_status funcStatus
logb *logger
logw io.Writer
done chan *C
reason string
mustFail bool
tempDir *tempDir
benchMem bool
startTime time.Time
timer
}
func (c *C) status() funcStatus {
return funcStatus(atomic.LoadUint32((*uint32)(&c._status)))
}
func (c *C) setStatus(s funcStatus) {
atomic.StoreUint32((*uint32)(&c._status), uint32(s))
}
func (c *C) stopNow() {
runtime.Goexit()
}
// logger is a concurrency safe byte.Buffer
type logger struct {
sync.Mutex
writer bytes.Buffer
}
func (l *logger) Write(buf []byte) (int, error) {
l.Lock()
defer l.Unlock()
return l.writer.Write(buf)
}
func (l *logger) WriteTo(w io.Writer) (int64, error) {
l.Lock()
defer l.Unlock()
return l.writer.WriteTo(w)
}
func (l *logger) String() string {
l.Lock()
defer l.Unlock()
return l.writer.String()
}
// -----------------------------------------------------------------------
// Handling of temporary files and directories.
type tempDir struct {
sync.Mutex
path string
counter int
}
func (td *tempDir) newPath() string {
td.Lock()
defer td.Unlock()
if td.path == "" {
var err error
for i := 0; i != 100; i++ {
path := fmt.Sprintf("%s%ccheck-%d", os.TempDir(), os.PathSeparator, rand.Int())
if err = os.Mkdir(path, 0700); err == nil {
td.path = path
break
}
}
if td.path == "" {
panic("Couldn't create temporary directory: " + err.Error())
}
}
result := filepath.Join(td.path, strconv.Itoa(td.counter))
td.counter += 1
return result
}
func (td *tempDir) removeAll() {
td.Lock()
defer td.Unlock()
if td.path != "" {
err := os.RemoveAll(td.path)
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Error cleaning up temporaries: "+err.Error())
}
}
}
// Create a new temporary directory which is automatically removed after
// the suite finishes running.
func (c *C) MkDir() string {
path := c.tempDir.newPath()
if err := os.Mkdir(path, 0700); err != nil {
panic(fmt.Sprintf("Couldn't create temporary directory %s: %s", path, err.Error()))
}
return path
}
// -----------------------------------------------------------------------
// Low-level logging functions.
func (c *C) log(args ...interface{}) {
c.writeLog([]byte(fmt.Sprint(args...) + "\n"))
}
func (c *C) logf(format string, args ...interface{}) {
c.writeLog([]byte(fmt.Sprintf(format+"\n", args...)))
}
func (c *C) logNewLine() {
c.writeLog([]byte{'\n'})
}
func (c *C) writeLog(buf []byte) {
c.logb.Write(buf)
if c.logw != nil {
c.logw.Write(buf)
}
}
func hasStringOrError(x interface{}) (ok bool) {
_, ok = x.(fmt.Stringer)
if ok {
return
}
_, ok = x.(error)
return
}
func (c *C) logValue(label string, value interface{}) {
if label == "" {
if hasStringOrError(value) {
c.logf("... %#v (%q)", value, value)
} else {
c.logf("... %#v", value)
}
} else if value == nil {
c.logf("... %s = nil", label)
} else {
if hasStringOrError(value) {
fv := fmt.Sprintf("%#v", value)
qv := fmt.Sprintf("%q", value)
if fv != qv {
c.logf("... %s %s = %s (%s)", label, reflect.TypeOf(value), fv, qv)
return
}
}
if s, ok := value.(string); ok && isMultiLine(s) {
c.logf(`... %s %s = "" +`, label, reflect.TypeOf(value))
c.logMultiLine(s)
} else {
c.logf("... %s %s = %#v", label, reflect.TypeOf(value), value)
}
}
}
func (c *C) logMultiLine(s string) {
b := make([]byte, 0, len(s)*2)
i := 0
n := len(s)
for i < n {
j := i + 1
for j < n && s[j-1] != '\n' {
j++
}
b = append(b, "... "...)
b = strconv.AppendQuote(b, s[i:j])
if j < n {
b = append(b, " +"...)
}
b = append(b, '\n')
i = j
}
c.writeLog(b)
}
func isMultiLine(s string) bool {
for i := 0; i+1 < len(s); i++ {
if s[i] == '\n' {
return true
}
}
return false
}
func (c *C) logString(issue string) {
c.log("... ", issue)
}
func (c *C) logCaller(skip int) {
// This is a bit heavier than it ought to be.
skip += 1 // Our own frame.
pc, callerFile, callerLine, ok := runtime.Caller(skip)
if !ok {
return
}
var testFile string
var testLine int
testFunc := runtime.FuncForPC(c.method.PC())
if runtime.FuncForPC(pc) != testFunc {
for {
skip += 1
if pc, file, line, ok := runtime.Caller(skip); ok {
// Note that the test line may be different on
// distinct calls for the same test. Showing
// the "internal" line is helpful when debugging.
if runtime.FuncForPC(pc) == testFunc {
testFile, testLine = file, line
break
}
} else {
break
}
}
}
if testFile != "" && (testFile != callerFile || testLine != callerLine) {
c.logCode(testFile, testLine)
}
c.logCode(callerFile, callerLine)
}
func (c *C) logCode(path string, line int) {
c.logf("%s:%d:", nicePath(path), line)
code, err := printLine(path, line)
if code == "" {
code = "..." // XXX Open the file and take the raw line.
if err != nil {
code += err.Error()
}
}
c.log(indent(code, " "))
}
var valueGo = filepath.Join("reflect", "value.go")
var asmGo = filepath.Join("runtime", "asm_")
func (c *C) logPanic(skip int, value interface{}) {
skip++ // Our own frame.
initialSkip := skip
for ; ; skip++ {
if pc, file, line, ok := runtime.Caller(skip); ok {
if skip == initialSkip {
c.logf("... Panic: %s (PC=0x%X)\n", value, pc)
}
name := niceFuncName(pc)
path := nicePath(file)
if strings.Contains(path, "/gopkg.in/check.v") {
continue
}
if name == "Value.call" && strings.HasSuffix(path, valueGo) {
continue
}
if (name == "call16" || name == "call32") && strings.Contains(path, asmGo) {
continue
}
c.logf("%s:%d\n in %s", nicePath(file), line, name)
} else {
break
}
}
}
func (c *C) logSoftPanic(issue string) {
c.log("... Panic: ", issue)
}
func (c *C) logArgPanic(method *methodType, expectedType string) {
c.logf("... Panic: %s argument should be %s",
niceFuncName(method.PC()), expectedType)
}
// -----------------------------------------------------------------------
// Some simple formatting helpers.
var initWD, initWDErr = os.Getwd()
func init() {
if initWDErr == nil {
initWD = strings.Replace(initWD, "\\", "/", -1) + "/"
}
}
func nicePath(path string) string {
if initWDErr == nil {
if strings.HasPrefix(path, initWD) {
return path[len(initWD):]
}
}
return path
}
func niceFuncPath(pc uintptr) string {
function := runtime.FuncForPC(pc)
if function != nil {
filename, line := function.FileLine(pc)
return fmt.Sprintf("%s:%d", nicePath(filename), line)
}
return "<unknown path>"
}
func niceFuncName(pc uintptr) string {
function := runtime.FuncForPC(pc)
if function != nil {
name := path.Base(function.Name())
if i := strings.Index(name, "."); i > 0 {
name = name[i+1:]
}
if strings.HasPrefix(name, "(*") {
if i := strings.Index(name, ")"); i > 0 {
name = name[2:i] + name[i+1:]
}
}
if i := strings.LastIndex(name, ".*"); i != -1 {
name = name[:i] + "." + name[i+2:]
}
if i := strings.LastIndex(name, "·"); i != -1 {
name = name[:i] + "." + name[i+2:]
}
return name
}
return "<unknown function>"
}
// -----------------------------------------------------------------------
// Result tracker to aggregate call results.
type Result struct {
Succeeded int
Failed int
Skipped int
Panicked int
FixturePanicked int
ExpectedFailures int
Missed int // Not even tried to run, related to a panic in the fixture.
RunError error // Houston, we've got a problem.
WorkDir string // If KeepWorkDir is true
}
type resultTracker struct {
result Result
_lastWasProblem bool
_waiting int
_missed int
_expectChan chan *C
_doneChan chan *C
_stopChan chan bool
}
func newResultTracker() *resultTracker {
return &resultTracker{_expectChan: make(chan *C), // Synchronous
_doneChan: make(chan *C, 32), // Asynchronous
_stopChan: make(chan bool)} // Synchronous
}
func (tracker *resultTracker) start() {
go tracker._loopRoutine()
}
func (tracker *resultTracker) waitAndStop() {
<-tracker._stopChan
}
func (tracker *resultTracker) expectCall(c *C) {
tracker._expectChan <- c
}
func (tracker *resultTracker) callDone(c *C) {
tracker._doneChan <- c
}
func (tracker *resultTracker) _loopRoutine() {
for {
var c *C
if tracker._waiting > 0 {
// Calls still running. Can't stop.
select {
// XXX Reindent this (not now to make diff clear)
case c = <-tracker._expectChan:
tracker._waiting += 1
case c = <-tracker._doneChan:
tracker._waiting -= 1
switch c.status() {
case succeededSt:
if c.kind == testKd {
if c.mustFail {
tracker.result.ExpectedFailures++
} else {
tracker.result.Succeeded++
}
}
case failedSt:
tracker.result.Failed++
case panickedSt:
if c.kind == fixtureKd {
tracker.result.FixturePanicked++
} else {
tracker.result.Panicked++
}
case fixturePanickedSt:
// Track it as missed, since the panic
// was on the fixture, not on the test.
tracker.result.Missed++
case missedSt:
tracker.result.Missed++
case skippedSt:
if c.kind == testKd {
tracker.result.Skipped++
}
}
}
} else {
// No calls. Can stop, but no done calls here.
select {
case tracker._stopChan <- true:
return
case c = <-tracker._expectChan:
tracker._waiting += 1
case c = <-tracker._doneChan:
panic("Tracker got an unexpected done call.")
}
}
}
}
// -----------------------------------------------------------------------
// The underlying suite runner.
type suiteRunner struct {
suite interface{}
setUpSuite, tearDownSuite *methodType
setUpTest, tearDownTest *methodType
tests []*methodType
tracker *resultTracker
tempDir *tempDir
keepDir bool
output *outputWriter
reportedProblemLast bool
benchTime time.Duration
benchMem bool
}
type RunConf struct {
Output io.Writer
Stream bool
Verbose bool
Filter string
Benchmark bool
BenchmarkTime time.Duration // Defaults to 1 second
BenchmarkMem bool
KeepWorkDir bool
}
// Create a new suiteRunner able to run all methods in the given suite.
func newSuiteRunner(suite interface{}, runConf *RunConf) *suiteRunner {
var conf RunConf
if runConf != nil {
conf = *runConf
}
if conf.Output == nil {
conf.Output = os.Stdout
}
if conf.Benchmark {
conf.Verbose = true
}
suiteType := reflect.TypeOf(suite)
suiteNumMethods := suiteType.NumMethod()
suiteValue := reflect.ValueOf(suite)
runner := &suiteRunner{
suite: suite,
output: newOutputWriter(conf.Output, conf.Stream, conf.Verbose),
tracker: newResultTracker(),
benchTime: conf.BenchmarkTime,
benchMem: conf.BenchmarkMem,
tempDir: &tempDir{},
keepDir: conf.KeepWorkDir,
tests: make([]*methodType, 0, suiteNumMethods),
}
if runner.benchTime == 0 {
runner.benchTime = 1 * time.Second
}
var filterRegexp *regexp.Regexp
if conf.Filter != "" {
if regexp, err := regexp.Compile(conf.Filter); err != nil {
msg := "Bad filter expression: " + err.Error()
runner.tracker.result.RunError = errors.New(msg)
return runner
} else {
filterRegexp = regexp
}
}
for i := 0; i != suiteNumMethods; i++ {
method := newMethod(suiteValue, i)
switch method.Info.Name {
case "SetUpSuite":
runner.setUpSuite = method
case "TearDownSuite":
runner.tearDownSuite = method
case "SetUpTest":
runner.setUpTest = method
case "TearDownTest":
runner.tearDownTest = method
default:
prefix := "Test"
if conf.Benchmark {
prefix = "Benchmark"
}
if !strings.HasPrefix(method.Info.Name, prefix) {
continue
}
if filterRegexp == nil || method.matches(filterRegexp) {
runner.tests = append(runner.tests, method)
}
}
}
return runner
}
// Run all methods in the given suite.
func (runner *suiteRunner) run() *Result {
if runner.tracker.result.RunError == nil && len(runner.tests) > 0 {
runner.tracker.start()
if runner.checkFixtureArgs() {
c := runner.runFixture(runner.setUpSuite, "", nil)
if c == nil || c.status() == succeededSt {
for i := 0; i != len(runner.tests); i++ {
c := runner.runTest(runner.tests[i])
if c.status() == fixturePanickedSt {
runner.skipTests(missedSt, runner.tests[i+1:])
break
}
}
} else if c != nil && c.status() == skippedSt {
runner.skipTests(skippedSt, runner.tests)
} else {
runner.skipTests(missedSt, runner.tests)
}
runner.runFixture(runner.tearDownSuite, "", nil)
} else {
runner.skipTests(missedSt, runner.tests)
}
runner.tracker.waitAndStop()
if runner.keepDir {
runner.tracker.result.WorkDir = runner.tempDir.path
} else {
runner.tempDir.removeAll()
}
}
return &runner.tracker.result
}
// Create a call object with the given suite method, and fork a
// goroutine with the provided dispatcher for running it.
func (runner *suiteRunner) forkCall(method *methodType, kind funcKind, testName string, logb *logger, dispatcher func(c *C)) *C {
var logw io.Writer
if runner.output.Stream {
logw = runner.output
}
if logb == nil {
logb = new(logger)
}
c := &C{
method: method,
kind: kind,
testName: testName,
logb: logb,
logw: logw,
tempDir: runner.tempDir,
done: make(chan *C, 1),
timer: timer{benchTime: runner.benchTime},
startTime: time.Now(),
benchMem: runner.benchMem,
}
runner.tracker.expectCall(c)
go (func() {
runner.reportCallStarted(c)
defer runner.callDone(c)
dispatcher(c)
})()
return c
}
// Same as forkCall(), but wait for call to finish before returning.
func (runner *suiteRunner) runFunc(method *methodType, kind funcKind, testName string, logb *logger, dispatcher func(c *C)) *C {
c := runner.forkCall(method, kind, testName, logb, dispatcher)
<-c.done
return c
}
// Handle a finished call. If there were any panics, update the call status
// accordingly. Then, mark the call as done and report to the tracker.
func (runner *suiteRunner) callDone(c *C) {
value := recover()
if value != nil {
switch v := value.(type) {
case *fixturePanic:
if v.status == skippedSt {
c.setStatus(skippedSt)
} else {
c.logSoftPanic("Fixture has panicked (see related PANIC)")
c.setStatus(fixturePanickedSt)
}
default:
c.logPanic(1, value)
c.setStatus(panickedSt)
}
}
if c.mustFail {
switch c.status() {
case failedSt:
c.setStatus(succeededSt)
case succeededSt:
c.setStatus(failedSt)
c.logString("Error: Test succeeded, but was expected to fail")
c.logString("Reason: " + c.reason)
}
}
runner.reportCallDone(c)
c.done <- c
}
// Runs a fixture call synchronously. The fixture will still be run in a
// goroutine like all suite methods, but this method will not return
// while the fixture goroutine is not done, because the fixture must be
// run in a desired order.
func (runner *suiteRunner) runFixture(method *methodType, testName string, logb *logger) *C {
if method != nil {
c := runner.runFunc(method, fixtureKd, testName, logb, func(c *C) {
c.ResetTimer()
c.StartTimer()
defer c.StopTimer()
c.method.Call([]reflect.Value{reflect.ValueOf(c)})
})
return c
}
return nil
}
// Run the fixture method with runFixture(), but panic with a fixturePanic{}
// in case the fixture method panics. This makes it easier to track the
// fixture panic together with other call panics within forkTest().
func (runner *suiteRunner) runFixtureWithPanic(method *methodType, testName string, logb *logger, skipped *bool) *C {
if skipped != nil && *skipped {
return nil
}
c := runner.runFixture(method, testName, logb)
if c != nil && c.status() != succeededSt {
if skipped != nil {
*skipped = c.status() == skippedSt
}
panic(&fixturePanic{c.status(), method})
}
return c
}
type fixturePanic struct {
status funcStatus
method *methodType
}
// Run the suite test method, together with the test-specific fixture,
// asynchronously.
func (runner *suiteRunner) forkTest(method *methodType) *C {
testName := method.String()
return runner.forkCall(method, testKd, testName, nil, func(c *C) {
var skipped bool
defer runner.runFixtureWithPanic(runner.tearDownTest, testName, nil, &skipped)
defer c.StopTimer()
benchN := 1
for {
runner.runFixtureWithPanic(runner.setUpTest, testName, c.logb, &skipped)
mt := c.method.Type()
if mt.NumIn() != 1 || mt.In(0) != reflect.TypeOf(c) {
// Rather than a plain panic, provide a more helpful message when
// the argument type is incorrect.
c.setStatus(panickedSt)
c.logArgPanic(c.method, "*check.C")
return
}
if strings.HasPrefix(c.method.Info.Name, "Test") {
c.ResetTimer()
c.StartTimer()
c.method.Call([]reflect.Value{reflect.ValueOf(c)})
return
}
if !strings.HasPrefix(c.method.Info.Name, "Benchmark") {
panic("unexpected method prefix: " + c.method.Info.Name)
}
runtime.GC()
c.N = benchN
c.ResetTimer()
c.StartTimer()
c.method.Call([]reflect.Value{reflect.ValueOf(c)})
c.StopTimer()
if c.status() != succeededSt || c.duration >= c.benchTime || benchN >= 1e9 {
return
}
perOpN := int(1e9)
if c.nsPerOp() != 0 {
perOpN = int(c.benchTime.Nanoseconds() / c.nsPerOp())
}
// Logic taken from the stock testing package:
// - Run more iterations than we think we'll need for a second (1.5x).
// - Don't grow too fast in case we had timing errors previously.
// - Be sure to run at least one more than last time.
benchN = max(min(perOpN+perOpN/2, 100*benchN), benchN+1)
benchN = roundUp(benchN)
skipped = true // Don't run the deferred one if this panics.
runner.runFixtureWithPanic(runner.tearDownTest, testName, nil, nil)
skipped = false
}
})
}
// Same as forkTest(), but wait for the test to finish before returning.
func (runner *suiteRunner) runTest(method *methodType) *C {
c := runner.forkTest(method)
<-c.done
return c
}
// Helper to mark tests as skipped or missed. A bit heavy for what
// it does, but it enables homogeneous handling of tracking, including
// nice verbose output.
func (runner *suiteRunner) skipTests(status funcStatus, methods []*methodType) {
for _, method := range methods {
runner.runFunc(method, testKd, "", nil, func(c *C) {
c.setStatus(status)
})
}
}
// Verify if the fixture arguments are *check.C. In case of errors,
// log the error as a panic in the fixture method call, and return false.
func (runner *suiteRunner) checkFixtureArgs() bool {
succeeded := true
argType := reflect.TypeOf(&C{})
for _, method := range []*methodType{runner.setUpSuite, runner.tearDownSuite, runner.setUpTest, runner.tearDownTest} {
if method != nil {
mt := method.Type()
if mt.NumIn() != 1 || mt.In(0) != argType {
succeeded = false
runner.runFunc(method, fixtureKd, "", nil, func(c *C) {
c.logArgPanic(method, "*check.C")
c.setStatus(panickedSt)
})
}
}
}
return succeeded
}
func (runner *suiteRunner) reportCallStarted(c *C) {
runner.output.WriteCallStarted("START", c)
}
func (runner *suiteRunner) reportCallDone(c *C) {
runner.tracker.callDone(c)
switch c.status() {
case succeededSt:
if c.mustFail {
runner.output.WriteCallSuccess("FAIL EXPECTED", c)
} else {
runner.output.WriteCallSuccess("PASS", c)
}
case skippedSt:
runner.output.WriteCallSuccess("SKIP", c)
case failedSt:
runner.output.WriteCallProblem("FAIL", c)
case panickedSt:
runner.output.WriteCallProblem("PANIC", c)
case fixturePanickedSt:
// That's a testKd call reporting that its fixture
// has panicked. The fixture call which caused the
// panic itself was tracked above. We'll report to
// aid debugging.
runner.output.WriteCallProblem("PANIC", c)
case missedSt:
runner.output.WriteCallSuccess("MISS", c)
}
}
// -----------------------------------------------------------------------
// Output writer manages atomic output writing according to settings.
type outputWriter struct {
m sync.Mutex
writer io.Writer
wroteCallProblemLast bool
Stream bool
Verbose bool
}
func newOutputWriter(writer io.Writer, stream, verbose bool) *outputWriter {
return &outputWriter{writer: writer, Stream: stream, Verbose: verbose}
}
func (ow *outputWriter) Write(content []byte) (n int, err error) {
ow.m.Lock()
n, err = ow.writer.Write(content)
ow.m.Unlock()
return
}
func (ow *outputWriter) WriteCallStarted(label string, c *C) {
if ow.Stream {
header := renderCallHeader(label, c, "", "\n")
ow.m.Lock()
ow.writer.Write([]byte(header))
ow.m.Unlock()
}
}
func (ow *outputWriter) WriteCallProblem(label string, c *C) {
var prefix string
if !ow.Stream {
prefix = "\n-----------------------------------" +
"-----------------------------------\n"
}
header := renderCallHeader(label, c, prefix, "\n\n")
ow.m.Lock()
ow.wroteCallProblemLast = true
ow.writer.Write([]byte(header))
if !ow.Stream {
c.logb.WriteTo(ow.writer)
}
ow.m.Unlock()
}
func (ow *outputWriter) WriteCallSuccess(label string, c *C) {
if ow.Stream || (ow.Verbose && c.kind == testKd) {
// TODO Use a buffer here.
var suffix string
if c.reason != "" {
suffix = " (" + c.reason + ")"
}
if c.status() == succeededSt {
suffix += "\t" + c.timerString()
}
suffix += "\n"
if ow.Stream {
suffix += "\n"
}
header := renderCallHeader(label, c, "", suffix)
ow.m.Lock()
// Resist temptation of using line as prefix above due to race.
if !ow.Stream && ow.wroteCallProblemLast {
header = "\n-----------------------------------" +
"-----------------------------------\n" +
header
}
ow.wroteCallProblemLast = false
ow.writer.Write([]byte(header))
ow.m.Unlock()
}
}
func renderCallHeader(label string, c *C, prefix, suffix string) string {
pc := c.method.PC()
return fmt.Sprintf("%s%s: %s: %s%s", prefix, label, niceFuncPath(pc),
niceFuncName(pc), suffix)
}

View File

@ -1,458 +0,0 @@
package check
import (
"fmt"
"reflect"
"regexp"
)
// -----------------------------------------------------------------------
// CommentInterface and Commentf helper, to attach extra information to checks.
type comment struct {
format string
args []interface{}
}
// Commentf returns an infomational value to use with Assert or Check calls.
// If the checker test fails, the provided arguments will be passed to
// fmt.Sprintf, and will be presented next to the logged failure.
//
// For example:
//
// c.Assert(v, Equals, 42, Commentf("Iteration #%d failed.", i))
//
// Note that if the comment is constant, a better option is to
// simply use a normal comment right above or next to the line, as
// it will also get printed with any errors:
//
// c.Assert(l, Equals, 8192) // Ensure buffer size is correct (bug #123)
//
func Commentf(format string, args ...interface{}) CommentInterface {
return &comment{format, args}
}
// CommentInterface must be implemented by types that attach extra
// information to failed checks. See the Commentf function for details.
type CommentInterface interface {
CheckCommentString() string
}
func (c *comment) CheckCommentString() string {
return fmt.Sprintf(c.format, c.args...)
}
// -----------------------------------------------------------------------
// The Checker interface.
// The Checker interface must be provided by checkers used with
// the Assert and Check verification methods.
type Checker interface {
Info() *CheckerInfo
Check(params []interface{}, names []string) (result bool, error string)
}
// See the Checker interface.
type CheckerInfo struct {
Name string
Params []string
}
func (info *CheckerInfo) Info() *CheckerInfo {
return info
}
// -----------------------------------------------------------------------
// Not checker logic inverter.
// The Not checker inverts the logic of the provided checker. The
// resulting checker will succeed where the original one failed, and
// vice-versa.
//
// For example:
//
// c.Assert(a, Not(Equals), b)
//
func Not(checker Checker) Checker {
return &notChecker{checker}
}
type notChecker struct {
sub Checker
}
func (checker *notChecker) Info() *CheckerInfo {
info := *checker.sub.Info()
info.Name = "Not(" + info.Name + ")"
return &info
}
func (checker *notChecker) Check(params []interface{}, names []string) (result bool, error string) {
result, error = checker.sub.Check(params, names)
result = !result
return
}
// -----------------------------------------------------------------------
// IsNil checker.
type isNilChecker struct {
*CheckerInfo
}
// The IsNil checker tests whether the obtained value is nil.
//
// For example:
//
// c.Assert(err, IsNil)
//
var IsNil Checker = &isNilChecker{
&CheckerInfo{Name: "IsNil", Params: []string{"value"}},
}
func (checker *isNilChecker) Check(params []interface{}, names []string) (result bool, error string) {
return isNil(params[0]), ""
}
func isNil(obtained interface{}) (result bool) {
if obtained == nil {
result = true
} else {
switch v := reflect.ValueOf(obtained); v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return v.IsNil()
}
}
return
}
// -----------------------------------------------------------------------
// NotNil checker. Alias for Not(IsNil), since it's so common.
type notNilChecker struct {
*CheckerInfo
}
// The NotNil checker verifies that the obtained value is not nil.
//
// For example:
//
// c.Assert(iface, NotNil)
//
// This is an alias for Not(IsNil), made available since it's a
// fairly common check.
//
var NotNil Checker = &notNilChecker{
&CheckerInfo{Name: "NotNil", Params: []string{"value"}},
}
func (checker *notNilChecker) Check(params []interface{}, names []string) (result bool, error string) {
return !isNil(params[0]), ""
}
// -----------------------------------------------------------------------
// Equals checker.
type equalsChecker struct {
*CheckerInfo
}
// The Equals checker verifies that the obtained value is equal to
// the expected value, according to usual Go semantics for ==.
//
// For example:
//
// c.Assert(value, Equals, 42)
//
var Equals Checker = &equalsChecker{
&CheckerInfo{Name: "Equals", Params: []string{"obtained", "expected"}},
}
func (checker *equalsChecker) Check(params []interface{}, names []string) (result bool, error string) {
defer func() {
if v := recover(); v != nil {
result = false
error = fmt.Sprint(v)
}
}()
return params[0] == params[1], ""
}
// -----------------------------------------------------------------------
// DeepEquals checker.
type deepEqualsChecker struct {
*CheckerInfo
}
// The DeepEquals checker verifies that the obtained value is deep-equal to
// the expected value. The check will work correctly even when facing
// slices, interfaces, and values of different types (which always fail
// the test).
//
// For example:
//
// c.Assert(value, DeepEquals, 42)
// c.Assert(array, DeepEquals, []string{"hi", "there"})
//
var DeepEquals Checker = &deepEqualsChecker{
&CheckerInfo{Name: "DeepEquals", Params: []string{"obtained", "expected"}},
}
func (checker *deepEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
return reflect.DeepEqual(params[0], params[1]), ""
}
// -----------------------------------------------------------------------
// HasLen checker.
type hasLenChecker struct {
*CheckerInfo
}
// The HasLen checker verifies that the obtained value has the
// provided length. In many cases this is superior to using Equals
// in conjuction with the len function because in case the check
// fails the value itself will be printed, instead of its length,
// providing more details for figuring the problem.
//
// For example:
//
// c.Assert(list, HasLen, 5)
//
var HasLen Checker = &hasLenChecker{
&CheckerInfo{Name: "HasLen", Params: []string{"obtained", "n"}},
}
func (checker *hasLenChecker) Check(params []interface{}, names []string) (result bool, error string) {
n, ok := params[1].(int)
if !ok {
return false, "n must be an int"
}
value := reflect.ValueOf(params[0])
switch value.Kind() {
case reflect.Map, reflect.Array, reflect.Slice, reflect.Chan, reflect.String:
default:
return false, "obtained value type has no length"
}
return value.Len() == n, ""
}
// -----------------------------------------------------------------------
// ErrorMatches checker.
type errorMatchesChecker struct {
*CheckerInfo
}
// The ErrorMatches checker verifies that the error value
// is non nil and matches the regular expression provided.
//
// For example:
//
// c.Assert(err, ErrorMatches, "perm.*denied")
//
var ErrorMatches Checker = errorMatchesChecker{
&CheckerInfo{Name: "ErrorMatches", Params: []string{"value", "regex"}},
}
func (checker errorMatchesChecker) Check(params []interface{}, names []string) (result bool, errStr string) {
if params[0] == nil {
return false, "Error value is nil"
}
err, ok := params[0].(error)
if !ok {
return false, "Value is not an error"
}
params[0] = err.Error()
names[0] = "error"
return matches(params[0], params[1])
}
// -----------------------------------------------------------------------
// Matches checker.
type matchesChecker struct {
*CheckerInfo
}
// The Matches checker verifies that the string provided as the obtained
// value (or the string resulting from obtained.String()) matches the
// regular expression provided.
//
// For example:
//
// c.Assert(err, Matches, "perm.*denied")
//
var Matches Checker = &matchesChecker{
&CheckerInfo{Name: "Matches", Params: []string{"value", "regex"}},
}
func (checker *matchesChecker) Check(params []interface{}, names []string) (result bool, error string) {
return matches(params[0], params[1])
}
func matches(value, regex interface{}) (result bool, error string) {
reStr, ok := regex.(string)
if !ok {
return false, "Regex must be a string"
}
valueStr, valueIsStr := value.(string)
if !valueIsStr {
if valueWithStr, valueHasStr := value.(fmt.Stringer); valueHasStr {
valueStr, valueIsStr = valueWithStr.String(), true
}
}
if valueIsStr {
matches, err := regexp.MatchString("^"+reStr+"$", valueStr)
if err != nil {
return false, "Can't compile regex: " + err.Error()
}
return matches, ""
}
return false, "Obtained value is not a string and has no .String()"
}
// -----------------------------------------------------------------------
// Panics checker.
type panicsChecker struct {
*CheckerInfo
}
// The Panics checker verifies that calling the provided zero-argument
// function will cause a panic which is deep-equal to the provided value.
//
// For example:
//
// c.Assert(func() { f(1, 2) }, Panics, &SomeErrorType{"BOOM"}).
//
//
var Panics Checker = &panicsChecker{
&CheckerInfo{Name: "Panics", Params: []string{"function", "expected"}},
}
func (checker *panicsChecker) Check(params []interface{}, names []string) (result bool, error string) {
f := reflect.ValueOf(params[0])
if f.Kind() != reflect.Func || f.Type().NumIn() != 0 {
return false, "Function must take zero arguments"
}
defer func() {
// If the function has not panicked, then don't do the check.
if error != "" {
return
}
params[0] = recover()
names[0] = "panic"
result = reflect.DeepEqual(params[0], params[1])
}()
f.Call(nil)
return false, "Function has not panicked"
}
type panicMatchesChecker struct {
*CheckerInfo
}
// The PanicMatches checker verifies that calling the provided zero-argument
// function will cause a panic with an error value matching
// the regular expression provided.
//
// For example:
//
// c.Assert(func() { f(1, 2) }, PanicMatches, `open.*: no such file or directory`).
//
//
var PanicMatches Checker = &panicMatchesChecker{
&CheckerInfo{Name: "PanicMatches", Params: []string{"function", "expected"}},
}
func (checker *panicMatchesChecker) Check(params []interface{}, names []string) (result bool, errmsg string) {
f := reflect.ValueOf(params[0])
if f.Kind() != reflect.Func || f.Type().NumIn() != 0 {
return false, "Function must take zero arguments"
}
defer func() {
// If the function has not panicked, then don't do the check.
if errmsg != "" {
return
}
obtained := recover()
names[0] = "panic"
if e, ok := obtained.(error); ok {
params[0] = e.Error()
} else if _, ok := obtained.(string); ok {
params[0] = obtained
} else {
errmsg = "Panic value is not a string or an error"
return
}
result, errmsg = matches(params[0], params[1])
}()
f.Call(nil)
return false, "Function has not panicked"
}
// -----------------------------------------------------------------------
// FitsTypeOf checker.
type fitsTypeChecker struct {
*CheckerInfo
}
// The FitsTypeOf checker verifies that the obtained value is
// assignable to a variable with the same type as the provided
// sample value.
//
// For example:
//
// c.Assert(value, FitsTypeOf, int64(0))
// c.Assert(value, FitsTypeOf, os.Error(nil))
//
var FitsTypeOf Checker = &fitsTypeChecker{
&CheckerInfo{Name: "FitsTypeOf", Params: []string{"obtained", "sample"}},
}
func (checker *fitsTypeChecker) Check(params []interface{}, names []string) (result bool, error string) {
obtained := reflect.ValueOf(params[0])
sample := reflect.ValueOf(params[1])
if !obtained.IsValid() {
return false, ""
}
if !sample.IsValid() {
return false, "Invalid sample value"
}
return obtained.Type().AssignableTo(sample.Type()), ""
}
// -----------------------------------------------------------------------
// Implements checker.
type implementsChecker struct {
*CheckerInfo
}
// The Implements checker verifies that the obtained value
// implements the interface specified via a pointer to an interface
// variable.
//
// For example:
//
// var e os.Error
// c.Assert(err, Implements, &e)
//
var Implements Checker = &implementsChecker{
&CheckerInfo{Name: "Implements", Params: []string{"obtained", "ifaceptr"}},
}
func (checker *implementsChecker) Check(params []interface{}, names []string) (result bool, error string) {
obtained := reflect.ValueOf(params[0])
ifaceptr := reflect.ValueOf(params[1])
if !obtained.IsValid() {
return false, ""
}
if !ifaceptr.IsValid() || ifaceptr.Kind() != reflect.Ptr || ifaceptr.Elem().Kind() != reflect.Interface {
return false, "ifaceptr should be a pointer to an interface variable"
}
return obtained.Type().Implements(ifaceptr.Elem().Type()), ""
}

View File

@ -1,131 +0,0 @@
// Extensions to the go-check unittest framework.
//
// NOTE: see https://github.com/go-check/check/pull/6 for reasons why these
// checkers live here.
package check
import (
"bytes"
"reflect"
)
// -----------------------------------------------------------------------
// IsTrue / IsFalse checker.
type isBoolValueChecker struct {
*CheckerInfo
expected bool
}
func (checker *isBoolValueChecker) Check(
params []interface{},
names []string) (
result bool,
error string) {
obtained, ok := params[0].(bool)
if !ok {
return false, "Argument to " + checker.Name + " must be bool"
}
return obtained == checker.expected, ""
}
// The IsTrue checker verifies that the obtained value is true.
//
// For example:
//
// c.Assert(value, IsTrue)
//
var IsTrue Checker = &isBoolValueChecker{
&CheckerInfo{Name: "IsTrue", Params: []string{"obtained"}},
true,
}
// The IsFalse checker verifies that the obtained value is false.
//
// For example:
//
// c.Assert(value, IsFalse)
//
var IsFalse Checker = &isBoolValueChecker{
&CheckerInfo{Name: "IsFalse", Params: []string{"obtained"}},
false,
}
// -----------------------------------------------------------------------
// BytesEquals checker.
type bytesEquals struct{}
func (b *bytesEquals) Check(params []interface{}, names []string) (bool, string) {
if len(params) != 2 {
return false, "BytesEqual takes 2 bytestring arguments"
}
b1, ok1 := params[0].([]byte)
b2, ok2 := params[1].([]byte)
if !(ok1 && ok2) {
return false, "Arguments to BytesEqual must both be bytestrings"
}
return bytes.Equal(b1, b2), ""
}
func (b *bytesEquals) Info() *CheckerInfo {
return &CheckerInfo{
Name: "BytesEquals",
Params: []string{"bytes_one", "bytes_two"},
}
}
// BytesEquals checker compares two bytes sequence using bytes.Equal.
//
// For example:
//
// c.Assert(b, BytesEquals, []byte("bar"))
//
// Main difference between DeepEquals and BytesEquals is that BytesEquals treats
// `nil` as empty byte sequence while DeepEquals doesn't.
//
// c.Assert(nil, BytesEquals, []byte("")) // succeeds
// c.Assert(nil, DeepEquals, []byte("")) // fails
var BytesEquals = &bytesEquals{}
// -----------------------------------------------------------------------
// HasKey checker.
type hasKey struct{}
func (h *hasKey) Check(params []interface{}, names []string) (bool, string) {
if len(params) != 2 {
return false, "HasKey takes 2 arguments: a map and a key"
}
mapValue := reflect.ValueOf(params[0])
if mapValue.Kind() != reflect.Map {
return false, "First argument to HasKey must be a map"
}
keyValue := reflect.ValueOf(params[1])
if !keyValue.Type().AssignableTo(mapValue.Type().Key()) {
return false, "Second argument must be assignable to the map key type"
}
return mapValue.MapIndex(keyValue).IsValid(), ""
}
func (h *hasKey) Info() *CheckerInfo {
return &CheckerInfo{
Name: "HasKey",
Params: []string{"obtained", "key"},
}
}
// The HasKey checker verifies that the obtained map contains the given key.
//
// For example:
//
// c.Assert(myMap, HasKey, "foo")
//
var HasKey = &hasKey{}

View File

@ -1,161 +0,0 @@
package check
import (
"bytes"
"fmt"
"reflect"
"time"
)
type compareFunc func(v1 interface{}, v2 interface{}) (bool, error)
type valueCompare struct {
Name string
Func compareFunc
Operator string
}
// v1 and v2 must have the same type
// return >0 if v1 > v2
// return 0 if v1 = v2
// return <0 if v1 < v2
// now we only support int, uint, float64, string and []byte comparison
func compare(v1 interface{}, v2 interface{}) (int, error) {
value1 := reflect.ValueOf(v1)
value2 := reflect.ValueOf(v2)
switch v1.(type) {
case int, int8, int16, int32, int64:
a1 := value1.Int()
a2 := value2.Int()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case uint, uint8, uint16, uint32, uint64:
a1 := value1.Uint()
a2 := value2.Uint()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case float32, float64:
a1 := value1.Float()
a2 := value2.Float()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case string:
a1 := value1.String()
a2 := value2.String()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case []byte:
a1 := value1.Bytes()
a2 := value2.Bytes()
return bytes.Compare(a1, a2), nil
case time.Time:
a1 := v1.(time.Time)
a2 := v2.(time.Time)
if a1.After(a2) {
return 1, nil
} else if a1.Equal(a2) {
return 0, nil
}
return -1, nil
case time.Duration:
a1 := v1.(time.Duration)
a2 := v2.(time.Duration)
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
default:
return 0, fmt.Errorf("type %T is not supported now", v1)
}
}
func less(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n < 0, nil
}
func lessEqual(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n <= 0, nil
}
func greater(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n > 0, nil
}
func greaterEqual(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n >= 0, nil
}
func (v *valueCompare) Check(params []interface{}, names []string) (bool, string) {
if len(params) != 2 {
return false, fmt.Sprintf("%s needs 2 arguments", v.Name)
}
v1 := params[0]
v2 := params[1]
v1Type := reflect.TypeOf(v1)
v2Type := reflect.TypeOf(v2)
if v1Type.Kind() != v2Type.Kind() {
return false, fmt.Sprintf("%s needs two same type, but %s != %s", v.Name, v1Type.Kind(), v2Type.Kind())
}
b, err := v.Func(v1, v2)
if err != nil {
return false, fmt.Sprintf("%s check err %v", v.Name, err)
}
return b, ""
}
func (v *valueCompare) Info() *CheckerInfo {
return &CheckerInfo{
Name: v.Name,
Params: []string{"compare_one", "compare_two"},
}
}
var Less = &valueCompare{Name: "Less", Func: less, Operator: "<"}
var LessEqual = &valueCompare{Name: "LessEqual", Func: lessEqual, Operator: "<="}
var Greater = &valueCompare{Name: "Greater", Func: greater, Operator: ">"}
var GreaterEqual = &valueCompare{Name: "GreaterEqual", Func: greaterEqual, Operator: ">="}

View File

@ -1,231 +0,0 @@
package check
import (
"fmt"
"strings"
"time"
)
// TestName returns the current test name in the form "SuiteName.TestName"
func (c *C) TestName() string {
return c.testName
}
// -----------------------------------------------------------------------
// Basic succeeding/failing logic.
// Failed returns whether the currently running test has already failed.
func (c *C) Failed() bool {
return c.status() == failedSt
}
// Fail marks the currently running test as failed.
//
// Something ought to have been previously logged so the developer can tell
// what went wrong. The higher level helper functions will fail the test
// and do the logging properly.
func (c *C) Fail() {
c.setStatus(failedSt)
}
// FailNow marks the currently running test as failed and stops running it.
// Something ought to have been previously logged so the developer can tell
// what went wrong. The higher level helper functions will fail the test
// and do the logging properly.
func (c *C) FailNow() {
c.Fail()
c.stopNow()
}
// Succeed marks the currently running test as succeeded, undoing any
// previous failures.
func (c *C) Succeed() {
c.setStatus(succeededSt)
}
// SucceedNow marks the currently running test as succeeded, undoing any
// previous failures, and stops running the test.
func (c *C) SucceedNow() {
c.Succeed()
c.stopNow()
}
// ExpectFailure informs that the running test is knowingly broken for
// the provided reason. If the test does not fail, an error will be reported
// to raise attention to this fact. This method is useful to temporarily
// disable tests which cover well known problems until a better time to
// fix the problem is found, without forgetting about the fact that a
// failure still exists.
func (c *C) ExpectFailure(reason string) {
if reason == "" {
panic("Missing reason why the test is expected to fail")
}
c.mustFail = true
c.reason = reason
}
// Skip skips the running test for the provided reason. If run from within
// SetUpTest, the individual test being set up will be skipped, and if run
// from within SetUpSuite, the whole suite is skipped.
func (c *C) Skip(reason string) {
if reason == "" {
panic("Missing reason why the test is being skipped")
}
c.reason = reason
c.setStatus(skippedSt)
c.stopNow()
}
// -----------------------------------------------------------------------
// Basic logging.
// GetTestLog returns the current test error output.
func (c *C) GetTestLog() string {
return c.logb.String()
}
// Log logs some information into the test error output.
// The provided arguments are assembled together into a string with fmt.Sprint.
func (c *C) Log(args ...interface{}) {
c.log(args...)
}
// Log logs some information into the test error output.
// The provided arguments are assembled together into a string with fmt.Sprintf.
func (c *C) Logf(format string, args ...interface{}) {
c.logf(format, args...)
}
// Output enables *C to be used as a logger in functions that require only
// the minimum interface of *log.Logger.
func (c *C) Output(calldepth int, s string) error {
d := time.Now().Sub(c.startTime)
msec := d / time.Millisecond
sec := d / time.Second
min := d / time.Minute
c.Logf("[LOG] %d:%02d.%03d %s", min, sec%60, msec%1000, s)
return nil
}
// Error logs an error into the test error output and marks the test as failed.
// The provided arguments are assembled together into a string with fmt.Sprint.
func (c *C) Error(args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprint("Error: ", fmt.Sprint(args...)))
c.logNewLine()
c.Fail()
}
// Errorf logs an error into the test error output and marks the test as failed.
// The provided arguments are assembled together into a string with fmt.Sprintf.
func (c *C) Errorf(format string, args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprintf("Error: "+format, args...))
c.logNewLine()
c.Fail()
}
// Fatal logs an error into the test error output, marks the test as failed, and
// stops the test execution. The provided arguments are assembled together into
// a string with fmt.Sprint.
func (c *C) Fatal(args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprint("Error: ", fmt.Sprint(args...)))
c.logNewLine()
c.FailNow()
}
// Fatlaf logs an error into the test error output, marks the test as failed, and
// stops the test execution. The provided arguments are assembled together into
// a string with fmt.Sprintf.
func (c *C) Fatalf(format string, args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprint("Error: ", fmt.Sprintf(format, args...)))
c.logNewLine()
c.FailNow()
}
// -----------------------------------------------------------------------
// Generic checks and assertions based on checkers.
// Check verifies if the first value matches the expected value according
// to the provided checker. If they do not match, an error is logged, the
// test is marked as failed, and the test execution continues.
//
// Some checkers may not need the expected argument (e.g. IsNil).
//
// Extra arguments provided to the function are logged next to the reported
// problem when the matching fails.
func (c *C) Check(obtained interface{}, checker Checker, args ...interface{}) bool {
return c.internalCheck("Check", obtained, checker, args...)
}
// Assert ensures that the first value matches the expected value according
// to the provided checker. If they do not match, an error is logged, the
// test is marked as failed, and the test execution stops.
//
// Some checkers may not need the expected argument (e.g. IsNil).
//
// Extra arguments provided to the function are logged next to the reported
// problem when the matching fails.
func (c *C) Assert(obtained interface{}, checker Checker, args ...interface{}) {
if !c.internalCheck("Assert", obtained, checker, args...) {
c.stopNow()
}
}
func (c *C) internalCheck(funcName string, obtained interface{}, checker Checker, args ...interface{}) bool {
if checker == nil {
c.logCaller(2)
c.logString(fmt.Sprintf("%s(obtained, nil!?, ...):", funcName))
c.logString("Oops.. you've provided a nil checker!")
c.logNewLine()
c.Fail()
return false
}
// If the last argument is a bug info, extract it out.
var comment CommentInterface
if len(args) > 0 {
if c, ok := args[len(args)-1].(CommentInterface); ok {
comment = c
args = args[:len(args)-1]
}
}
params := append([]interface{}{obtained}, args...)
info := checker.Info()
if len(params) != len(info.Params) {
names := append([]string{info.Params[0], info.Name}, info.Params[1:]...)
c.logCaller(2)
c.logString(fmt.Sprintf("%s(%s):", funcName, strings.Join(names, ", ")))
c.logString(fmt.Sprintf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(names), len(params)+1))
c.logNewLine()
c.Fail()
return false
}
// Copy since it may be mutated by Check.
names := append([]string{}, info.Params...)
// Do the actual check.
result, error := checker.Check(params, names)
if !result || error != "" {
c.logCaller(2)
for i := 0; i != len(params); i++ {
c.logValue(names[i], params[i])
}
if comment != nil {
c.logString(comment.CheckCommentString())
}
if error != "" {
c.logString(error)
}
c.logNewLine()
c.Fail()
return false
}
return true
}

View File

@ -1,168 +0,0 @@
package check
import (
"bytes"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
)
func indent(s, with string) (r string) {
eol := true
for i := 0; i != len(s); i++ {
c := s[i]
switch {
case eol && c == '\n' || c == '\r':
case c == '\n' || c == '\r':
eol = true
case eol:
eol = false
s = s[:i] + with + s[i:]
i += len(with)
}
}
return s
}
func printLine(filename string, line int) (string, error) {
fset := token.NewFileSet()
file, err := os.Open(filename)
if err != nil {
return "", err
}
fnode, err := parser.ParseFile(fset, filename, file, parser.ParseComments)
if err != nil {
return "", err
}
config := &printer.Config{Mode: printer.UseSpaces, Tabwidth: 4}
lp := &linePrinter{fset: fset, fnode: fnode, line: line, config: config}
ast.Walk(lp, fnode)
result := lp.output.Bytes()
// Comments leave \n at the end.
n := len(result)
for n > 0 && result[n-1] == '\n' {
n--
}
return string(result[:n]), nil
}
type linePrinter struct {
config *printer.Config
fset *token.FileSet
fnode *ast.File
line int
output bytes.Buffer
stmt ast.Stmt
}
func (lp *linePrinter) emit() bool {
if lp.stmt != nil {
lp.trim(lp.stmt)
lp.printWithComments(lp.stmt)
lp.stmt = nil
return true
}
return false
}
func (lp *linePrinter) printWithComments(n ast.Node) {
nfirst := lp.fset.Position(n.Pos()).Line
nlast := lp.fset.Position(n.End()).Line
for _, g := range lp.fnode.Comments {
cfirst := lp.fset.Position(g.Pos()).Line
clast := lp.fset.Position(g.End()).Line
if clast == nfirst-1 && lp.fset.Position(n.Pos()).Column == lp.fset.Position(g.Pos()).Column {
for _, c := range g.List {
lp.output.WriteString(c.Text)
lp.output.WriteByte('\n')
}
}
if cfirst >= nfirst && cfirst <= nlast && n.End() <= g.List[0].Slash {
// The printer will not include the comment if it starts past
// the node itself. Trick it into printing by overlapping the
// slash with the end of the statement.
g.List[0].Slash = n.End() - 1
}
}
node := &printer.CommentedNode{n, lp.fnode.Comments}
lp.config.Fprint(&lp.output, lp.fset, node)
}
func (lp *linePrinter) Visit(n ast.Node) (w ast.Visitor) {
if n == nil {
if lp.output.Len() == 0 {
lp.emit()
}
return nil
}
first := lp.fset.Position(n.Pos()).Line
last := lp.fset.Position(n.End()).Line
if first <= lp.line && last >= lp.line {
// Print the innermost statement containing the line.
if stmt, ok := n.(ast.Stmt); ok {
if _, ok := n.(*ast.BlockStmt); !ok {
lp.stmt = stmt
}
}
if first == lp.line && lp.emit() {
return nil
}
return lp
}
return nil
}
func (lp *linePrinter) trim(n ast.Node) bool {
stmt, ok := n.(ast.Stmt)
if !ok {
return true
}
line := lp.fset.Position(n.Pos()).Line
if line != lp.line {
return false
}
switch stmt := stmt.(type) {
case *ast.IfStmt:
stmt.Body = lp.trimBlock(stmt.Body)
case *ast.SwitchStmt:
stmt.Body = lp.trimBlock(stmt.Body)
case *ast.TypeSwitchStmt:
stmt.Body = lp.trimBlock(stmt.Body)
case *ast.CaseClause:
stmt.Body = lp.trimList(stmt.Body)
case *ast.CommClause:
stmt.Body = lp.trimList(stmt.Body)
case *ast.BlockStmt:
stmt.List = lp.trimList(stmt.List)
}
return true
}
func (lp *linePrinter) trimBlock(stmt *ast.BlockStmt) *ast.BlockStmt {
if !lp.trim(stmt) {
return lp.emptyBlock(stmt)
}
stmt.Rbrace = stmt.Lbrace
return stmt
}
func (lp *linePrinter) trimList(stmts []ast.Stmt) []ast.Stmt {
for i := 0; i != len(stmts); i++ {
if !lp.trim(stmts[i]) {
stmts[i] = lp.emptyStmt(stmts[i])
break
}
}
return stmts
}
func (lp *linePrinter) emptyStmt(n ast.Node) *ast.ExprStmt {
return &ast.ExprStmt{&ast.Ellipsis{n.Pos(), nil}}
}
func (lp *linePrinter) emptyBlock(n ast.Node) *ast.BlockStmt {
p := n.Pos()
return &ast.BlockStmt{p, []ast.Stmt{lp.emptyStmt(n)}, p}
}

View File

@ -1,175 +0,0 @@
package check
import (
"bufio"
"flag"
"fmt"
"os"
"testing"
"time"
)
// -----------------------------------------------------------------------
// Test suite registry.
var allSuites []interface{}
// Suite registers the given value as a test suite to be run. Any methods
// starting with the Test prefix in the given value will be considered as
// a test method.
func Suite(suite interface{}) interface{} {
allSuites = append(allSuites, suite)
return suite
}
// -----------------------------------------------------------------------
// Public running interface.
var (
oldFilterFlag = flag.String("gocheck.f", "", "Regular expression selecting which tests and/or suites to run")
oldVerboseFlag = flag.Bool("gocheck.v", false, "Verbose mode")
oldStreamFlag = flag.Bool("gocheck.vv", false, "Super verbose mode (disables output caching)")
oldBenchFlag = flag.Bool("gocheck.b", false, "Run benchmarks")
oldBenchTime = flag.Duration("gocheck.btime", 1*time.Second, "approximate run time for each benchmark")
oldListFlag = flag.Bool("gocheck.list", false, "List the names of all tests that will be run")
oldWorkFlag = flag.Bool("gocheck.work", false, "Display and do not remove the test working directory")
newFilterFlag = flag.String("check.f", "", "Regular expression selecting which tests and/or suites to run")
newVerboseFlag = flag.Bool("check.v", false, "Verbose mode")
newStreamFlag = flag.Bool("check.vv", false, "Super verbose mode (disables output caching)")
newBenchFlag = flag.Bool("check.b", false, "Run benchmarks")
newBenchTime = flag.Duration("check.btime", 1*time.Second, "approximate run time for each benchmark")
newBenchMem = flag.Bool("check.bmem", false, "Report memory benchmarks")
newListFlag = flag.Bool("check.list", false, "List the names of all tests that will be run")
newWorkFlag = flag.Bool("check.work", false, "Display and do not remove the test working directory")
)
// TestingT runs all test suites registered with the Suite function,
// printing results to stdout, and reporting any failures back to
// the "testing" package.
func TestingT(testingT *testing.T) {
benchTime := *newBenchTime
if benchTime == 1*time.Second {
benchTime = *oldBenchTime
}
conf := &RunConf{
Filter: *oldFilterFlag + *newFilterFlag,
Verbose: *oldVerboseFlag || *newVerboseFlag,
Stream: *oldStreamFlag || *newStreamFlag,
Benchmark: *oldBenchFlag || *newBenchFlag,
BenchmarkTime: benchTime,
BenchmarkMem: *newBenchMem,
KeepWorkDir: *oldWorkFlag || *newWorkFlag,
}
if *oldListFlag || *newListFlag {
w := bufio.NewWriter(os.Stdout)
for _, name := range ListAll(conf) {
fmt.Fprintln(w, name)
}
w.Flush()
return
}
result := RunAll(conf)
println(result.String())
if !result.Passed() {
testingT.Fail()
}
}
// RunAll runs all test suites registered with the Suite function, using the
// provided run configuration.
func RunAll(runConf *RunConf) *Result {
result := Result{}
for _, suite := range allSuites {
result.Add(Run(suite, runConf))
}
return &result
}
// Run runs the provided test suite using the provided run configuration.
func Run(suite interface{}, runConf *RunConf) *Result {
runner := newSuiteRunner(suite, runConf)
return runner.run()
}
// ListAll returns the names of all the test functions registered with the
// Suite function that will be run with the provided run configuration.
func ListAll(runConf *RunConf) []string {
var names []string
for _, suite := range allSuites {
names = append(names, List(suite, runConf)...)
}
return names
}
// List returns the names of the test functions in the given
// suite that will be run with the provided run configuration.
func List(suite interface{}, runConf *RunConf) []string {
var names []string
runner := newSuiteRunner(suite, runConf)
for _, t := range runner.tests {
names = append(names, t.String())
}
return names
}
// -----------------------------------------------------------------------
// Result methods.
func (r *Result) Add(other *Result) {
r.Succeeded += other.Succeeded
r.Skipped += other.Skipped
r.Failed += other.Failed
r.Panicked += other.Panicked
r.FixturePanicked += other.FixturePanicked
r.ExpectedFailures += other.ExpectedFailures
r.Missed += other.Missed
if r.WorkDir != "" && other.WorkDir != "" {
r.WorkDir += ":" + other.WorkDir
} else if other.WorkDir != "" {
r.WorkDir = other.WorkDir
}
}
func (r *Result) Passed() bool {
return (r.Failed == 0 && r.Panicked == 0 &&
r.FixturePanicked == 0 && r.Missed == 0 &&
r.RunError == nil)
}
func (r *Result) String() string {
if r.RunError != nil {
return "ERROR: " + r.RunError.Error()
}
var value string
if r.Failed == 0 && r.Panicked == 0 && r.FixturePanicked == 0 &&
r.Missed == 0 {
value = "OK: "
} else {
value = "OOPS: "
}
value += fmt.Sprintf("%d passed", r.Succeeded)
if r.Skipped != 0 {
value += fmt.Sprintf(", %d skipped", r.Skipped)
}
if r.ExpectedFailures != 0 {
value += fmt.Sprintf(", %d expected failures", r.ExpectedFailures)
}
if r.Failed != 0 {
value += fmt.Sprintf(", %d FAILED", r.Failed)
}
if r.Panicked != 0 {
value += fmt.Sprintf(", %d PANICKED", r.Panicked)
}
if r.FixturePanicked != 0 {
value += fmt.Sprintf(", %d FIXTURE-PANICKED", r.FixturePanicked)
}
if r.Missed != 0 {
value += fmt.Sprintf(", %d MISSED", r.Missed)
}
if r.WorkDir != "" {
value += "\nWORK=" + r.WorkDir
}
return value
}

View File

@ -1,20 +0,0 @@
Copyright (C) 2013-2016 by Maxim Bublis <b@codemonkey.ru>
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,488 +0,0 @@
// Copyright (C) 2013-2015 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// Package uuid provides implementation of Universally Unique Identifier (UUID).
// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and
// version 2 (as specified in DCE 1.1).
package uuid
import (
"bytes"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"fmt"
"hash"
"net"
"os"
"sync"
"time"
)
// UUID layout variants.
const (
VariantNCS = iota
VariantRFC4122
VariantMicrosoft
VariantFuture
)
// UUID DCE domains.
const (
DomainPerson = iota
DomainGroup
DomainOrg
)
// Difference in 100-nanosecond intervals between
// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970).
const epochStart = 122192928000000000
// Used in string method conversion
const dash byte = '-'
// UUID v1/v2 storage.
var (
storageMutex sync.Mutex
storageOnce sync.Once
epochFunc = unixTimeFunc
clockSequence uint16
lastTime uint64
hardwareAddr [6]byte
posixUID = uint32(os.Getuid())
posixGID = uint32(os.Getgid())
)
// String parse helpers.
var (
urnPrefix = []byte("urn:uuid:")
byteGroups = []int{8, 4, 4, 4, 12}
)
func initClockSequence() {
buf := make([]byte, 2)
safeRandom(buf)
clockSequence = binary.BigEndian.Uint16(buf)
}
func initHardwareAddr() {
interfaces, err := net.Interfaces()
if err == nil {
for _, iface := range interfaces {
if len(iface.HardwareAddr) >= 6 {
copy(hardwareAddr[:], iface.HardwareAddr)
return
}
}
}
// Initialize hardwareAddr randomly in case
// of real network interfaces absence
safeRandom(hardwareAddr[:])
// Set multicast bit as recommended in RFC 4122
hardwareAddr[0] |= 0x01
}
func initStorage() {
initClockSequence()
initHardwareAddr()
}
func safeRandom(dest []byte) {
if _, err := rand.Read(dest); err != nil {
panic(err)
}
}
// Returns difference in 100-nanosecond intervals between
// UUID epoch (October 15, 1582) and current time.
// This is default epoch calculation function.
func unixTimeFunc() uint64 {
return epochStart + uint64(time.Now().UnixNano()/100)
}
// UUID representation compliant with specification
// described in RFC 4122.
type UUID [16]byte
// NullUUID can be used with the standard sql package to represent a
// UUID value that can be NULL in the database
type NullUUID struct {
UUID UUID
Valid bool
}
// The nil UUID is special form of UUID that is specified to have all
// 128 bits set to zero.
var Nil = UUID{}
// Predefined namespace UUIDs.
var (
NamespaceDNS, _ = FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
NamespaceURL, _ = FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
NamespaceOID, _ = FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8")
NamespaceX500, _ = FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8")
)
// And returns result of binary AND of two UUIDs.
func And(u1 UUID, u2 UUID) UUID {
u := UUID{}
for i := 0; i < 16; i++ {
u[i] = u1[i] & u2[i]
}
return u
}
// Or returns result of binary OR of two UUIDs.
func Or(u1 UUID, u2 UUID) UUID {
u := UUID{}
for i := 0; i < 16; i++ {
u[i] = u1[i] | u2[i]
}
return u
}
// Equal returns true if u1 and u2 equals, otherwise returns false.
func Equal(u1 UUID, u2 UUID) bool {
return bytes.Equal(u1[:], u2[:])
}
// Version returns algorithm version used to generate UUID.
func (u UUID) Version() uint {
return uint(u[6] >> 4)
}
// Variant returns UUID layout variant.
func (u UUID) Variant() uint {
switch {
case (u[8] & 0x80) == 0x00:
return VariantNCS
case (u[8]&0xc0)|0x80 == 0x80:
return VariantRFC4122
case (u[8]&0xe0)|0xc0 == 0xc0:
return VariantMicrosoft
}
return VariantFuture
}
// Bytes returns bytes slice representation of UUID.
func (u UUID) Bytes() []byte {
return u[:]
}
// Returns canonical string representation of UUID:
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.
func (u UUID) String() string {
buf := make([]byte, 36)
hex.Encode(buf[0:8], u[0:4])
buf[8] = dash
hex.Encode(buf[9:13], u[4:6])
buf[13] = dash
hex.Encode(buf[14:18], u[6:8])
buf[18] = dash
hex.Encode(buf[19:23], u[8:10])
buf[23] = dash
hex.Encode(buf[24:], u[10:])
return string(buf)
}
// SetVersion sets version bits.
func (u *UUID) SetVersion(v byte) {
u[6] = (u[6] & 0x0f) | (v << 4)
}
// SetVariant sets variant bits as described in RFC 4122.
func (u *UUID) SetVariant() {
u[8] = (u[8] & 0xbf) | 0x80
}
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by String.
func (u UUID) MarshalText() (text []byte, err error) {
text = []byte(u.String())
return
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// Following formats are supported:
// "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}",
// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8"
func (u *UUID) UnmarshalText(text []byte) (err error) {
if len(text) < 32 {
err = fmt.Errorf("uuid: UUID string too short: %s", text)
return
}
t := text[:]
braced := false
if bytes.Equal(t[:9], urnPrefix) {
t = t[9:]
} else if t[0] == '{' {
braced = true
t = t[1:]
}
b := u[:]
for i, byteGroup := range byteGroups {
if i > 0 && t[0] == '-' {
t = t[1:]
} else if i > 0 && t[0] != '-' {
err = fmt.Errorf("uuid: invalid string format")
return
}
if i == 2 {
if !bytes.Contains([]byte("012345"), []byte{t[0]}) {
err = fmt.Errorf("uuid: invalid version number: %s", t[0])
return
}
}
if len(t) < byteGroup {
err = fmt.Errorf("uuid: UUID string too short: %s", text)
return
}
if i == 4 && len(t) > byteGroup &&
((braced && t[byteGroup] != '}') || len(t[byteGroup:]) > 1 || !braced) {
err = fmt.Errorf("uuid: UUID string too long: %s", t)
return
}
_, err = hex.Decode(b[:byteGroup/2], t[:byteGroup])
if err != nil {
return
}
t = t[byteGroup:]
b = b[byteGroup/2:]
}
return
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (u UUID) MarshalBinary() (data []byte, err error) {
data = u.Bytes()
return
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It will return error if the slice isn't 16 bytes long.
func (u *UUID) UnmarshalBinary(data []byte) (err error) {
if len(data) != 16 {
err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data))
return
}
copy(u[:], data)
return
}
// Value implements the driver.Valuer interface.
func (u UUID) Value() (driver.Value, error) {
return u.String(), nil
}
// Scan implements the sql.Scanner interface.
// A 16-byte slice is handled by UnmarshalBinary, while
// a longer byte slice or a string is handled by UnmarshalText.
func (u *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
if len(src) == 16 {
return u.UnmarshalBinary(src)
}
return u.UnmarshalText(src)
case string:
return u.UnmarshalText([]byte(src))
}
return fmt.Errorf("uuid: cannot convert %T to UUID", src)
}
// Value implements the driver.Valuer interface.
func (u NullUUID) Value() (driver.Value, error) {
if !u.Valid {
return nil, nil
}
// Delegate to UUID Value function
return u.UUID.Value()
}
// Scan implements the sql.Scanner interface.
func (u *NullUUID) Scan(src interface{}) error {
if src == nil {
u.UUID, u.Valid = Nil, false
return nil
}
// Delegate to UUID Scan function
u.Valid = true
return u.UUID.Scan(src)
}
// FromBytes returns UUID converted from raw byte slice input.
// It will return error if the slice isn't 16 bytes long.
func FromBytes(input []byte) (u UUID, err error) {
err = u.UnmarshalBinary(input)
return
}
// FromBytesOrNil returns UUID converted from raw byte slice input.
// Same behavior as FromBytes, but returns a Nil UUID on error.
func FromBytesOrNil(input []byte) UUID {
uuid, err := FromBytes(input)
if err != nil {
return Nil
}
return uuid
}
// FromString returns UUID parsed from string input.
// Input is expected in a form accepted by UnmarshalText.
func FromString(input string) (u UUID, err error) {
err = u.UnmarshalText([]byte(input))
return
}
// FromStringOrNil returns UUID parsed from string input.
// Same behavior as FromString, but returns a Nil UUID on error.
func FromStringOrNil(input string) UUID {
uuid, err := FromString(input)
if err != nil {
return Nil
}
return uuid
}
// Returns UUID v1/v2 storage state.
// Returns epoch timestamp, clock sequence, and hardware address.
func getStorage() (uint64, uint16, []byte) {
storageOnce.Do(initStorage)
storageMutex.Lock()
defer storageMutex.Unlock()
timeNow := epochFunc()
// Clock changed backwards since last UUID generation.
// Should increase clock sequence.
if timeNow <= lastTime {
clockSequence++
}
lastTime = timeNow
return timeNow, clockSequence, hardwareAddr[:]
}
// NewV1 returns UUID based on current timestamp and MAC address.
func NewV1() UUID {
u := UUID{}
timeNow, clockSeq, hardwareAddr := getStorage()
binary.BigEndian.PutUint32(u[0:], uint32(timeNow))
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
copy(u[10:], hardwareAddr)
u.SetVersion(1)
u.SetVariant()
return u
}
// NewV2 returns DCE Security UUID based on POSIX UID/GID.
func NewV2(domain byte) UUID {
u := UUID{}
timeNow, clockSeq, hardwareAddr := getStorage()
switch domain {
case DomainPerson:
binary.BigEndian.PutUint32(u[0:], posixUID)
case DomainGroup:
binary.BigEndian.PutUint32(u[0:], posixGID)
}
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
u[9] = domain
copy(u[10:], hardwareAddr)
u.SetVersion(2)
u.SetVariant()
return u
}
// NewV3 returns UUID based on MD5 hash of namespace UUID and name.
func NewV3(ns UUID, name string) UUID {
u := newFromHash(md5.New(), ns, name)
u.SetVersion(3)
u.SetVariant()
return u
}
// NewV4 returns random generated UUID.
func NewV4() UUID {
u := UUID{}
safeRandom(u[:])
u.SetVersion(4)
u.SetVariant()
return u
}
// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name.
func NewV5(ns UUID, name string) UUID {
u := newFromHash(sha1.New(), ns, name)
u.SetVersion(5)
u.SetVariant()
return u
}
// Returns UUID based on hashing of namespace UUID and name.
func newFromHash(h hash.Hash, ns UUID, name string) UUID {
u := UUID{}
h.Write(ns[:])
h.Write([]byte(name))
copy(u[:], h.Sum(nil))
return u
}

View File

@ -1,20 +0,0 @@
The MIT License (MIT)
Copyright (c) 2014 siddontang
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,27 +0,0 @@
package hack
import (
"reflect"
"unsafe"
)
// no copy to change slice to string
// use your own risk
func String(b []byte) (s string) {
pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b))
pstring := (*reflect.StringHeader)(unsafe.Pointer(&s))
pstring.Data = pbytes.Data
pstring.Len = pbytes.Len
return
}
// no copy to change string to slice
// use your own risk
func Slice(s string) (b []byte) {
pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b))
pstring := (*reflect.StringHeader)(unsafe.Pointer(&s))
pbytes.Data = pstring.Data
pbytes.Len = pstring.Len
pbytes.Cap = pstring.Len
return
}

View File

@ -1,39 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ioutil2
import (
"io"
"io/ioutil"
"os"
"path"
)
// Write file to temp and atomically move when everything else succeeds.
func WriteFileAtomic(filename string, data []byte, perm os.FileMode) error {
dir, name := path.Split(filename)
f, err := ioutil.TempFile(dir, name)
if err != nil {
return err
}
n, err := f.Write(data)
f.Close()
if err == nil && n < len(data) {
err = io.ErrShortWrite
} else {
err = os.Chmod(f.Name(), perm)
}
if err != nil {
os.Remove(f.Name())
return err
}
return os.Rename(f.Name(), filename)
}
// Check file exists or not
func FileExists(name string) bool {
_, err := os.Stat(name)
return !os.IsNotExist(err)
}

View File

@ -1,69 +0,0 @@
package ioutil2
import (
"errors"
"io"
)
var ErrExceedLimit = errors.New("write exceed limit")
func NewSectionWriter(w io.WriterAt, off int64, n int64) *SectionWriter {
return &SectionWriter{w, off, off, off + n}
}
type SectionWriter struct {
w io.WriterAt
base int64
off int64
limit int64
}
func (s *SectionWriter) Write(p []byte) (n int, err error) {
if s.off >= s.limit {
return 0, ErrExceedLimit
}
if max := s.limit - s.off; int64(len(p)) > max {
return 0, ErrExceedLimit
}
n, err = s.w.WriteAt(p, s.off)
s.off += int64(n)
return
}
var errWhence = errors.New("Seek: invalid whence")
var errOffset = errors.New("Seek: invalid offset")
func (s *SectionWriter) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, errWhence
case 0:
offset += s.base
case 1:
offset += s.off
case 2:
offset += s.limit
}
if offset < s.base {
return 0, errOffset
}
s.off = offset
return offset - s.base, nil
}
func (s *SectionWriter) WriteAt(p []byte, off int64) (n int, err error) {
if off < 0 || off >= s.limit-s.base {
return 0, errOffset
}
off += s.base
if max := s.limit - off; int64(len(p)) > max {
return 0, ErrExceedLimit
}
return s.w.WriteAt(p, off)
}
// Size returns the size of the section in bytes.
func (s *SectionWriter) Size() int64 { return s.limit - s.base }

View File

@ -1,146 +0,0 @@
// Copyright 2013, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync2
import (
"sync"
"sync/atomic"
"time"
)
type AtomicInt32 int32
func (i *AtomicInt32) Add(n int32) int32 {
return atomic.AddInt32((*int32)(i), n)
}
func (i *AtomicInt32) Set(n int32) {
atomic.StoreInt32((*int32)(i), n)
}
func (i *AtomicInt32) Get() int32 {
return atomic.LoadInt32((*int32)(i))
}
func (i *AtomicInt32) CompareAndSwap(oldval, newval int32) (swapped bool) {
return atomic.CompareAndSwapInt32((*int32)(i), oldval, newval)
}
type AtomicUint32 uint32
func (i *AtomicUint32) Add(n uint32) uint32 {
return atomic.AddUint32((*uint32)(i), n)
}
func (i *AtomicUint32) Set(n uint32) {
atomic.StoreUint32((*uint32)(i), n)
}
func (i *AtomicUint32) Get() uint32 {
return atomic.LoadUint32((*uint32)(i))
}
func (i *AtomicUint32) CompareAndSwap(oldval, newval uint32) (swapped bool) {
return atomic.CompareAndSwapUint32((*uint32)(i), oldval, newval)
}
type AtomicInt64 int64
func (i *AtomicInt64) Add(n int64) int64 {
return atomic.AddInt64((*int64)(i), n)
}
func (i *AtomicInt64) Set(n int64) {
atomic.StoreInt64((*int64)(i), n)
}
func (i *AtomicInt64) Get() int64 {
return atomic.LoadInt64((*int64)(i))
}
func (i *AtomicInt64) CompareAndSwap(oldval, newval int64) (swapped bool) {
return atomic.CompareAndSwapInt64((*int64)(i), oldval, newval)
}
type AtomicUint64 uint64
func (i *AtomicUint64) Add(n uint64) uint64 {
return atomic.AddUint64((*uint64)(i), n)
}
func (i *AtomicUint64) Set(n uint64) {
atomic.StoreUint64((*uint64)(i), n)
}
func (i *AtomicUint64) Get() uint64 {
return atomic.LoadUint64((*uint64)(i))
}
func (i *AtomicUint64) CompareAndSwap(oldval, newval uint64) (swapped bool) {
return atomic.CompareAndSwapUint64((*uint64)(i), oldval, newval)
}
type AtomicDuration int64
func (d *AtomicDuration) Add(duration time.Duration) time.Duration {
return time.Duration(atomic.AddInt64((*int64)(d), int64(duration)))
}
func (d *AtomicDuration) Set(duration time.Duration) {
atomic.StoreInt64((*int64)(d), int64(duration))
}
func (d *AtomicDuration) Get() time.Duration {
return time.Duration(atomic.LoadInt64((*int64)(d)))
}
func (d *AtomicDuration) CompareAndSwap(oldval, newval time.Duration) (swapped bool) {
return atomic.CompareAndSwapInt64((*int64)(d), int64(oldval), int64(newval))
}
// AtomicString gives you atomic-style APIs for string, but
// it's only a convenience wrapper that uses a mutex. So, it's
// not as efficient as the rest of the atomic types.
type AtomicString struct {
mu sync.Mutex
str string
}
func (s *AtomicString) Set(str string) {
s.mu.Lock()
s.str = str
s.mu.Unlock()
}
func (s *AtomicString) Get() string {
s.mu.Lock()
str := s.str
s.mu.Unlock()
return str
}
func (s *AtomicString) CompareAndSwap(oldval, newval string) (swqpped bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.str == oldval {
s.str = newval
return true
}
return false
}
type AtomicBool int32
func (b *AtomicBool) Set(v bool) {
if v {
atomic.StoreInt32((*int32)(b), 1)
} else {
atomic.StoreInt32((*int32)(b), 0)
}
}
func (b *AtomicBool) Get() bool {
return atomic.LoadInt32((*int32)(b)) == 1
}

View File

@ -1,65 +0,0 @@
package sync2
import (
"sync"
"sync/atomic"
"time"
)
func NewSemaphore(initialCount int) *Semaphore {
res := &Semaphore{
counter: int64(initialCount),
}
res.cond.L = &res.lock
return res
}
type Semaphore struct {
lock sync.Mutex
cond sync.Cond
counter int64
}
func (s *Semaphore) Release() {
s.lock.Lock()
s.counter += 1
if s.counter >= 0 {
s.cond.Signal()
}
s.lock.Unlock()
}
func (s *Semaphore) Acquire() {
s.lock.Lock()
for s.counter < 1 {
s.cond.Wait()
}
s.counter -= 1
s.lock.Unlock()
}
func (s *Semaphore) AcquireTimeout(timeout time.Duration) bool {
done := make(chan bool, 1)
// Gate used to communicate between the threads and decide what the result
// is. If the main thread decides, we have timed out, otherwise we succeed.
decided := new(int32)
go func() {
s.Acquire()
if atomic.SwapInt32(decided, 1) == 0 {
done <- true
} else {
// If we already decided the result, and this thread did not win
s.Release()
}
}()
select {
case <-done:
return true
case <-time.NewTimer(timeout).C:
if atomic.SwapInt32(decided, 1) == 1 {
// The other thread already decided the result
return true
}
return false
}
}

View File

@ -1,27 +0,0 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,22 +0,0 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

View File

@ -1,447 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package context defines the Context type, which carries deadlines,
// cancelation signals, and other request-scoped values across API boundaries
// and between processes.
//
// Incoming requests to a server should create a Context, and outgoing calls to
// servers should accept a Context. The chain of function calls between must
// propagate the Context, optionally replacing it with a modified copy created
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
//
// Do not store Contexts inside a struct type; instead, pass a Context
// explicitly to each function that needs it. The Context should be the first
// parameter, typically named ctx:
//
// func DoSomething(ctx context.Context, arg Arg) error {
// // ... use ctx ...
// }
//
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
// if you are unsure about which Context to use.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The same Context may be passed to functions running in different goroutines;
// Contexts are safe for simultaneous use by multiple goroutines.
//
// See http://blog.golang.org/context for example code for a server that uses
// Contexts.
package context // import "golang.org/x/net/context"
import (
"errors"
"fmt"
"sync"
"time"
)
// A Context carries a deadline, a cancelation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
//
// WithCancel arranges for Done to be closed when cancel is called;
// WithDeadline arranges for Done to be closed when the deadline
// expires; WithTimeout arranges for Done to be closed when the timeout
// elapses.
//
// Done is provided for use in select statements:
//
// // Stream generates values with DoSomething and sends them to out
// // until DoSomething returns an error or ctx.Done is closed.
// func Stream(ctx context.Context, out <-chan Value) error {
// for {
// v, err := DoSomething(ctx)
// if err != nil {
// return err
// }
// select {
// case <-ctx.Done():
// return ctx.Err()
// case out <- v:
// }
// }
// }
//
// See http://blog.golang.org/pipelines for more examples of how to use
// a Done channel for cancelation.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
//
// A key identifies a specific value in a Context. Functions that wish
// to store values in Context typically allocate a key in a global
// variable then use that key as the argument to context.WithValue and
// Context.Value. A key can be any type that supports equality;
// packages should define keys as an unexported type to avoid
// collisions.
//
// Packages that define a Context key should provide type-safe accessors
// for the values stores using that key:
//
// // Package user defines a User type that's stored in Contexts.
// package user
//
// import "golang.org/x/net/context"
//
// // User is the type of value stored in the Contexts.
// type User struct {...}
//
// // key is an unexported type for keys defined in this package.
// // This prevents collisions with keys defined in other packages.
// type key int
//
// // userKey is the key for user.User values in Contexts. It is
// // unexported; clients use user.NewContext and user.FromContext
// // instead of using this key directly.
// var userKey key = 0
//
// // NewContext returns a new Context that carries value u.
// func NewContext(ctx context.Context, u *User) context.Context {
// return context.WithValue(ctx, userKey, u)
// }
//
// // FromContext returns the User value stored in ctx, if any.
// func FromContext(ctx context.Context) (*User, bool) {
// u, ok := ctx.Value(userKey).(*User)
// return u, ok
// }
Value(key interface{}) interface{}
}
// Canceled is the error returned by Context.Err when the context is canceled.
var Canceled = errors.New("context canceled")
// DeadlineExceeded is the error returned by Context.Err when the context's
// deadline passes.
var DeadlineExceeded = errors.New("context deadline exceeded")
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case background:
return "context.Background"
case todo:
return "context.TODO"
}
return "unknown empty Context"
}
var (
background = new(emptyCtx)
todo = new(emptyCtx)
)
// Background returns a non-nil, empty Context. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function,
// initialization, and tests, and as the top-level Context for incoming
// requests.
func Background() Context {
return background
}
// TODO returns a non-nil, empty Context. Code should use context.TODO when
// it's unclear which Context to use or it is not yet available (because the
// surrounding function has not yet been extended to accept a Context
// parameter). TODO is recognized by static analysis tools that determine
// whether Contexts are propagated correctly in a program.
func TODO() Context {
return todo
}
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc func()
// WithCancel returns a copy of parent with a new Done channel. The returned
// context's Done channel is closed when the returned cancel function is called
// or when the parent context's Done channel is closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := newCancelCtx(parent)
propagateCancel(parent, &c)
return &c, func() { c.cancel(true, Canceled) }
}
// newCancelCtx returns an initialized cancelCtx.
func newCancelCtx(parent Context) cancelCtx {
return cancelCtx{
Context: parent,
done: make(chan struct{}),
}
}
// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) {
if parent.Done() == nil {
return // parent is never canceled
}
if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err)
} else {
if p.children == nil {
p.children = make(map[canceler]bool)
}
p.children[child] = true
}
p.mu.Unlock()
} else {
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
case <-child.Done():
}
}()
}
}
// parentCancelCtx follows a chain of parent references until it finds a
// *cancelCtx. This function understands how each of the concrete types in this
// package represents its parent.
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
for {
switch c := parent.(type) {
case *cancelCtx:
return c, true
case *timerCtx:
return &c.cancelCtx, true
case *valueCtx:
parent = c.Context
default:
return nil, false
}
}
}
// removeChild removes a context from its parent.
func removeChild(parent Context, child canceler) {
p, ok := parentCancelCtx(parent)
if !ok {
return
}
p.mu.Lock()
if p.children != nil {
delete(p.children, child)
}
p.mu.Unlock()
}
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
cancel(removeFromParent bool, err error)
Done() <-chan struct{}
}
// A cancelCtx can be canceled. When canceled, it also cancels any children
// that implement canceler.
type cancelCtx struct {
Context
done chan struct{} // closed by the first cancel call.
mu sync.Mutex
children map[canceler]bool // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
}
func (c *cancelCtx) Done() <-chan struct{} {
return c.done
}
func (c *cancelCtx) Err() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.err
}
func (c *cancelCtx) String() string {
return fmt.Sprintf("%v.WithCancel", c.Context)
}
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
close(c.done)
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err)
}
c.children = nil
c.mu.Unlock()
if removeFromParent {
removeChild(c.Context, c)
}
}
// WithDeadline returns a copy of the parent context with the deadline adjusted
// to be no later than d. If the parent's deadline is already earlier than d,
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
// context's Done channel is closed when the deadline expires, when the returned
// cancel function is called, or when the parent context's Done channel is
// closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) {
// The current deadline is already sooner than the new one.
return WithCancel(parent)
}
c := &timerCtx{
cancelCtx: newCancelCtx(parent),
deadline: deadline,
}
propagateCancel(parent, c)
d := deadline.Sub(time.Now())
if d <= 0 {
c.cancel(true, DeadlineExceeded) // deadline has already passed
return c, func() { c.cancel(true, Canceled) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(d, func() {
c.cancel(true, DeadlineExceeded)
})
}
return c, func() { c.cancel(true, Canceled) }
}
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
}
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
return c.deadline, true
}
func (c *timerCtx) String() string {
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now()))
}
func (c *timerCtx) cancel(removeFromParent bool, err error) {
c.cancelCtx.cancel(false, err)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
}
c.mu.Lock()
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
c.mu.Unlock()
}
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete:
//
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
// defer cancel() // releases resources if slowOperation completes before timeout elapses
// return slowOperation(ctx)
// }
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}
// WithValue returns a copy of parent in which the value associated with key is
// val.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
func WithValue(parent Context, key interface{}, val interface{}) Context {
return &valueCtx{parent, key, val}
}
// A valueCtx carries a key-value pair. It implements Value for that key and
// delegates all other calls to the embedded Context.
type valueCtx struct {
Context
key, val interface{}
}
func (c *valueCtx) String() string {
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val)
}
func (c *valueCtx) Value(key interface{}) interface{} {
if c.key == key {
return c.val
}
return c.Context.Value(key)
}

View File

@ -1,26 +1,25 @@
package canal
import (
"context"
"fmt"
"io/ioutil"
"os"
"path"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go-mysql/client"
"github.com/siddontang/go-mysql/dump"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/replication"
"github.com/siddontang/go-mysql/schema"
"github.com/siddontang/go/sync2"
)
var errCanalClosed = errors.New("canal was closed")
// Canal can sync your MySQL data into everywhere, like Elasticsearch, Redis, etc...
// MySQL must open row format for binlog
type Canal struct {
@ -30,48 +29,49 @@ type Canal struct {
master *masterInfo
dumper *dump.Dumper
dumped bool
dumpDoneCh chan struct{}
syncer *replication.BinlogSyncer
rsLock sync.Mutex
rsHandlers []RowsEventHandler
eventHandler EventHandler
connLock sync.Mutex
conn *client.Conn
wg sync.WaitGroup
tableLock sync.RWMutex
tables map[string]*schema.Table
errorTablesGetTime map[string]time.Time
tableLock sync.Mutex
tables map[string]*schema.Table
tableMatchCache map[string]bool
includeTableRegex []*regexp.Regexp
excludeTableRegex []*regexp.Regexp
quit chan struct{}
closed sync2.AtomicBool
ctx context.Context
cancel context.CancelFunc
}
// canal will retry fetching unknown table's meta after UnknownTableRetryPeriod
var UnknownTableRetryPeriod = time.Second * time.Duration(10)
var ErrExcludedTable = errors.New("excluded table meta")
func NewCanal(cfg *Config) (*Canal, error) {
c := new(Canal)
c.cfg = cfg
c.closed.Set(false)
c.quit = make(chan struct{})
os.MkdirAll(cfg.DataDir, 0755)
c.ctx, c.cancel = context.WithCancel(context.Background())
c.dumpDoneCh = make(chan struct{})
c.rsHandlers = make([]RowsEventHandler, 0, 4)
c.eventHandler = &DummyEventHandler{}
c.tables = make(map[string]*schema.Table)
if c.cfg.DiscardNoMetaRowEvent {
c.errorTablesGetTime = make(map[string]time.Time)
}
c.master = &masterInfo{}
var err error
if c.master, err = loadMasterInfo(c.masterInfoPath()); err != nil {
return nil, errors.Trace(err)
} else if len(c.master.Addr) != 0 && c.master.Addr != c.cfg.Addr {
log.Infof("MySQL addr %s in old master.info, but new %s, reset", c.master.Addr, c.cfg.Addr)
// may use another MySQL, reset
c.master = &masterInfo{}
}
c.master.Addr = c.cfg.Addr
if err := c.prepareDumper(); err != nil {
if err = c.prepareDumper(); err != nil {
return nil, errors.Trace(err)
}
@ -83,6 +83,33 @@ func NewCanal(cfg *Config) (*Canal, error) {
return nil, errors.Trace(err)
}
// init table filter
if n := len(c.cfg.IncludeTableRegex); n > 0 {
c.includeTableRegex = make([]*regexp.Regexp, n)
for i, val := range c.cfg.IncludeTableRegex {
reg, err := regexp.Compile(val)
if err != nil {
return nil, errors.Trace(err)
}
c.includeTableRegex[i] = reg
}
}
if n := len(c.cfg.ExcludeTableRegex); n > 0 {
c.excludeTableRegex = make([]*regexp.Regexp, n)
for i, val := range c.cfg.ExcludeTableRegex {
reg, err := regexp.Compile(val)
if err != nil {
return nil, errors.Trace(err)
}
c.excludeTableRegex[i] = reg
}
}
if c.includeTableRegex != nil || c.excludeTableRegex != nil {
c.tableMatchCache = make(map[string]bool)
}
return c, nil
}
@ -114,6 +141,15 @@ func (c *Canal) prepareDumper() error {
c.dumper.AddTables(tableDB, tables...)
}
charset := c.cfg.Charset
c.dumper.SetCharset(charset)
c.dumper.SetWhere(c.cfg.Dump.Where)
c.dumper.SkipMasterData(c.cfg.Dump.SkipMasterData)
c.dumper.SetMaxAllowedPacket(c.cfg.Dump.MaxAllowedPacketMB)
// Use hex blob for mysqldump
c.dumper.SetHexBlob(true)
for _, ignoreTable := range c.cfg.Dump.IgnoreTables {
if seps := strings.Split(ignoreTable, ","); len(seps) == 2 {
c.dumper.AddIgnoreTables(seps[0], seps[1])
@ -129,92 +165,208 @@ func (c *Canal) prepareDumper() error {
return nil
}
func (c *Canal) Start() error {
c.wg.Add(1)
go c.run()
// Run will first try to dump all data from MySQL master `mysqldump`,
// then sync from the binlog position in the dump data.
// It will run forever until meeting an error or Canal closed.
func (c *Canal) Run() error {
return c.run()
}
return nil
// RunFrom will sync from the binlog position directly, ignore mysqldump.
func (c *Canal) RunFrom(pos mysql.Position) error {
c.master.Update(pos)
return c.Run()
}
func (c *Canal) StartFromGTID(set mysql.GTIDSet) error {
c.master.UpdateGTIDSet(set)
return c.Run()
}
// Dump all data from MySQL master `mysqldump`, ignore sync binlog.
func (c *Canal) Dump() error {
if c.dumped {
return errors.New("the method Dump can't be called twice")
}
c.dumped = true
defer close(c.dumpDoneCh)
return c.dump()
}
func (c *Canal) run() error {
defer c.wg.Done()
defer func() {
c.cancel()
}()
if err := c.tryDump(); err != nil {
log.Errorf("canal dump mysql err: %v", err)
return errors.Trace(err)
c.master.UpdateTimestamp(uint32(time.Now().Unix()))
if !c.dumped {
c.dumped = true
err := c.tryDump()
close(c.dumpDoneCh)
if err != nil {
log.Errorf("canal dump mysql err: %v", err)
return errors.Trace(err)
}
}
close(c.dumpDoneCh)
if err := c.startSyncBinlog(); err != nil {
if !c.isClosed() {
log.Errorf("canal start sync binlog err: %v", err)
}
if err := c.runSyncBinlog(); err != nil {
log.Errorf("canal start sync binlog err: %v", err)
return errors.Trace(err)
}
return nil
}
func (c *Canal) isClosed() bool {
return c.closed.Get()
}
func (c *Canal) Close() {
log.Infof("close canal")
log.Infof("closing canal")
c.m.Lock()
defer c.m.Unlock()
if c.isClosed() {
return
}
c.closed.Set(true)
close(c.quit)
c.cancel()
c.connLock.Lock()
c.conn.Close()
c.conn = nil
c.connLock.Unlock()
c.syncer.Close()
if c.syncer != nil {
c.syncer.Close()
c.syncer = nil
}
c.master.Close()
c.wg.Wait()
c.eventHandler.OnPosSynced(c.master.Position(), true)
}
func (c *Canal) WaitDumpDone() <-chan struct{} {
return c.dumpDoneCh
}
func (c *Canal) Ctx() context.Context {
return c.ctx
}
func (c *Canal) checkTableMatch(key string) bool {
// no filter, return true
if c.tableMatchCache == nil {
return true
}
c.tableLock.RLock()
rst, ok := c.tableMatchCache[key]
c.tableLock.RUnlock()
if ok {
// cache hit
return rst
}
matchFlag := false
// check include
if c.includeTableRegex != nil {
for _, reg := range c.includeTableRegex {
if reg.MatchString(key) {
matchFlag = true
break
}
}
}
// check exclude
if matchFlag && c.excludeTableRegex != nil {
for _, reg := range c.excludeTableRegex {
if reg.MatchString(key) {
matchFlag = false
break
}
}
}
c.tableLock.Lock()
c.tableMatchCache[key] = matchFlag
c.tableLock.Unlock()
return matchFlag
}
func (c *Canal) GetTable(db string, table string) (*schema.Table, error) {
key := fmt.Sprintf("%s.%s", db, table)
c.tableLock.Lock()
// if table is excluded, return error and skip parsing event or dump
if !c.checkTableMatch(key) {
return nil, ErrExcludedTable
}
c.tableLock.RLock()
t, ok := c.tables[key]
c.tableLock.Unlock()
c.tableLock.RUnlock()
if ok {
return t, nil
}
if c.cfg.DiscardNoMetaRowEvent {
c.tableLock.RLock()
lastTime, ok := c.errorTablesGetTime[key]
c.tableLock.RUnlock()
if ok && time.Now().Sub(lastTime) < UnknownTableRetryPeriod {
return nil, schema.ErrMissingTableMeta
}
}
t, err := schema.NewTable(c, db, table)
if err != nil {
return nil, errors.Trace(err)
// check table not exists
if ok, err1 := schema.IsTableExist(c, db, table); err1 == nil && !ok {
return nil, schema.ErrTableNotExist
}
// work around : RDS HAHeartBeat
// ref : https://github.com/alibaba/canal/blob/master/parse/src/main/java/com/alibaba/otter/canal/parse/inbound/mysql/dbsync/LogEventConvert.java#L385
// issue : https://github.com/alibaba/canal/issues/222
// This is a common error in RDS that canal can't get HAHealthCheckSchema's meta, so we mock a table meta.
// If canal just skip and log error, as RDS HA heartbeat interval is very short, so too many HAHeartBeat errors will be logged.
if key == schema.HAHealthCheckSchema {
// mock ha_health_check meta
ta := &schema.Table{
Schema: db,
Name: table,
Columns: make([]schema.TableColumn, 0, 2),
Indexes: make([]*schema.Index, 0),
}
ta.AddColumn("id", "bigint(20)", "", "")
ta.AddColumn("type", "char(1)", "", "")
c.tableLock.Lock()
c.tables[key] = ta
c.tableLock.Unlock()
return ta, nil
}
// if DiscardNoMetaRowEvent is true, we just log this error
if c.cfg.DiscardNoMetaRowEvent {
c.tableLock.Lock()
c.errorTablesGetTime[key] = time.Now()
c.tableLock.Unlock()
// log error and return ErrMissingTableMeta
log.Errorf("canal get table meta err: %v", errors.Trace(err))
return nil, schema.ErrMissingTableMeta
}
return nil, err
}
c.tableLock.Lock()
c.tables[key] = t
if c.cfg.DiscardNoMetaRowEvent {
// if get table info success, delete this key from errorTablesGetTime
delete(c.errorTablesGetTime, key)
}
c.tableLock.Unlock()
return t, nil
}
// ClearTableCache clear table cache
func (c *Canal) ClearTableCache(db []byte, table []byte) {
key := fmt.Sprintf("%s.%s", db, table)
c.tableLock.Lock()
delete(c.tables, key)
if c.cfg.DiscardNoMetaRowEvent {
delete(c.errorTablesGetTime, key)
}
c.tableLock.Unlock()
}
// Check MySQL binlog row image, must be in FULL, MINIMAL, NOBLOB
func (c *Canal) CheckBinlogRowImage(image string) error {
// need to check MySQL binlog row image? full, minimal or noblob?
@ -246,23 +398,34 @@ func (c *Canal) checkBinlogRowFormat() error {
}
func (c *Canal) prepareSyncer() error {
seps := strings.Split(c.cfg.Addr, ":")
if len(seps) != 2 {
return errors.Errorf("invalid mysql addr format %s, must host:port", c.cfg.Addr)
}
port, err := strconv.ParseUint(seps[1], 10, 16)
if err != nil {
return errors.Trace(err)
}
cfg := replication.BinlogSyncerConfig{
ServerID: c.cfg.ServerID,
Flavor: c.cfg.Flavor,
Host: seps[0],
Port: uint16(port),
User: c.cfg.User,
Password: c.cfg.Password,
ServerID: c.cfg.ServerID,
Flavor: c.cfg.Flavor,
User: c.cfg.User,
Password: c.cfg.Password,
Charset: c.cfg.Charset,
HeartbeatPeriod: c.cfg.HeartbeatPeriod,
ReadTimeout: c.cfg.ReadTimeout,
UseDecimal: c.cfg.UseDecimal,
ParseTime: c.cfg.ParseTime,
SemiSyncEnabled: c.cfg.SemiSyncEnabled,
}
if strings.Contains(c.cfg.Addr, "/") {
cfg.Host = c.cfg.Addr
} else {
seps := strings.Split(c.cfg.Addr, ":")
if len(seps) != 2 {
return errors.Errorf("invalid mysql addr format %s, must host:port", c.cfg.Addr)
}
port, err := strconv.ParseUint(seps[1], 10, 16)
if err != nil {
return errors.Trace(err)
}
cfg.Host = seps[0]
cfg.Port = uint16(port)
}
c.syncer = replication.NewBinlogSyncer(cfg)
@ -270,10 +433,6 @@ func (c *Canal) prepareSyncer() error {
return nil
}
func (c *Canal) masterInfoPath() string {
return path.Join(c.cfg.DataDir, "master.info")
}
// Execute a SQL
func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err error) {
c.connLock.Lock()
@ -303,5 +462,13 @@ func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err
}
func (c *Canal) SyncedPosition() mysql.Position {
return c.master.Pos()
return c.master.Position()
}
func (c *Canal) SyncedTimestamp() uint32 {
return c.master.timestamp
}
func (c *Canal) SyncedGTIDSet() mysql.GTIDSet {
return c.master.GTIDSet()
}

169
vendor/github.com/siddontang/go-mysql/canal/canal_test.go generated vendored Normal file → Executable file
View File

@ -1,13 +1,15 @@
package canal
import (
"bytes"
"flag"
"fmt"
"os"
"testing"
"time"
"github.com/ngaut/log"
"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go-mysql/mysql"
)
@ -27,19 +29,28 @@ func (s *canalTestSuite) SetUpSuite(c *C) {
cfg := NewDefaultConfig()
cfg.Addr = fmt.Sprintf("%s:3306", *testHost)
cfg.User = "root"
cfg.HeartbeatPeriod = 200 * time.Millisecond
cfg.ReadTimeout = 300 * time.Millisecond
cfg.Dump.ExecutionPath = "mysqldump"
cfg.Dump.TableDB = "test"
cfg.Dump.Tables = []string{"canal_test"}
cfg.Dump.Where = "id>0"
os.RemoveAll(cfg.DataDir)
// include & exclude config
cfg.IncludeTableRegex = make([]string, 1)
cfg.IncludeTableRegex[0] = ".*\\.canal_test"
cfg.ExcludeTableRegex = make([]string, 2)
cfg.ExcludeTableRegex[0] = "mysql\\..*"
cfg.ExcludeTableRegex[1] = ".*\\..*_inner"
var err error
s.c, err = NewCanal(cfg)
c.Assert(err, IsNil)
s.execute(c, "DROP TABLE IF EXISTS test.canal_test")
sql := `
CREATE TABLE IF NOT EXISTS test.canal_test (
id int AUTO_INCREMENT,
id int AUTO_INCREMENT,
content blob DEFAULT NULL,
name varchar(100),
PRIMARY KEY(id)
)ENGINE=innodb;
@ -48,16 +59,22 @@ func (s *canalTestSuite) SetUpSuite(c *C) {
s.execute(c, sql)
s.execute(c, "DELETE FROM test.canal_test")
s.execute(c, "INSERT INTO test.canal_test (name) VALUES (?), (?), (?)", "a", "b", "c")
s.execute(c, "INSERT INTO test.canal_test (content, name) VALUES (?, ?), (?, ?), (?, ?)", "1", "a", `\0\ndsfasdf`, "b", "", "c")
s.execute(c, "SET GLOBAL binlog_format = 'ROW'")
s.c.RegRowsEventHandler(&testRowsEventHandler{})
err = s.c.Start()
c.Assert(err, IsNil)
s.c.SetEventHandler(&testEventHandler{c: c})
go func() {
err = s.c.Run()
c.Assert(err, IsNil)
}()
}
func (s *canalTestSuite) TearDownSuite(c *C) {
// To test the heartbeat and read timeout,so need to sleep 1 seconds without data transmission
c.Logf("Start testing the heartbeat and read timeout")
time.Sleep(time.Second)
if s.c != nil {
s.c.Close()
s.c = nil
@ -70,16 +87,19 @@ func (s *canalTestSuite) execute(c *C, query string, args ...interface{}) *mysql
return r
}
type testRowsEventHandler struct {
type testEventHandler struct {
DummyEventHandler
c *C
}
func (h *testRowsEventHandler) Do(e *RowsEvent) error {
log.Infof("%s %v\n", e.Action, e.Rows)
func (h *testEventHandler) OnRow(e *RowsEvent) error {
log.Infof("OnRow %s %v\n", e.Action, e.Rows)
return nil
}
func (h *testRowsEventHandler) String() string {
return "testRowsEventHandler"
func (h *testEventHandler) String() string {
return "testEventHandler"
}
func (s *canalTestSuite) TestCanal(c *C) {
@ -88,7 +108,126 @@ func (s *canalTestSuite) TestCanal(c *C) {
for i := 1; i < 10; i++ {
s.execute(c, "INSERT INTO test.canal_test (name) VALUES (?)", fmt.Sprintf("%d", i))
}
s.execute(c, "ALTER TABLE test.canal_test ADD `age` INT(5) NOT NULL AFTER `name`")
s.execute(c, "INSERT INTO test.canal_test (name,age) VALUES (?,?)", "d", "18")
err := s.c.CatchMasterPos(100)
err := s.c.CatchMasterPos(10 * time.Second)
c.Assert(err, IsNil)
}
func (s *canalTestSuite) TestCanalFilter(c *C) {
// included
sch, err := s.c.GetTable("test", "canal_test")
c.Assert(err, IsNil)
c.Assert(sch, NotNil)
_, err = s.c.GetTable("not_exist_db", "canal_test")
c.Assert(errors.Trace(err), Not(Equals), ErrExcludedTable)
// excluded
sch, err = s.c.GetTable("test", "canal_test_inner")
c.Assert(errors.Cause(err), Equals, ErrExcludedTable)
c.Assert(sch, IsNil)
sch, err = s.c.GetTable("mysql", "canal_test")
c.Assert(errors.Cause(err), Equals, ErrExcludedTable)
c.Assert(sch, IsNil)
sch, err = s.c.GetTable("not_exist_db", "not_canal_test")
c.Assert(errors.Cause(err), Equals, ErrExcludedTable)
c.Assert(sch, IsNil)
}
func TestCreateTableExp(t *testing.T) {
cases := []string{
"CREATE TABLE `mydb.mytable` (`id` int(10)) ENGINE=InnoDB",
"CREATE TABLE `mytable` (`id` int(10)) ENGINE=InnoDB",
"CREATE TABLE IF NOT EXISTS `mytable` (`id` int(10)) ENGINE=InnoDB",
"CREATE TABLE IF NOT EXISTS mytable (`id` int(10)) ENGINE=InnoDB",
}
table := []byte("mytable")
db := []byte("mydb")
for _, s := range cases {
m := expCreateTable.FindSubmatch([]byte(s))
mLen := len(m)
if m == nil || !bytes.Equal(m[mLen-1], table) || (len(m[mLen-2]) > 0 && !bytes.Equal(m[mLen-2], db)) {
t.Fatalf("TestCreateTableExp: case %s failed\n", s)
}
}
}
func TestAlterTableExp(t *testing.T) {
cases := []string{
"ALTER TABLE `mydb`.`mytable` ADD `field2` DATE NULL AFTER `field1`;",
"ALTER TABLE `mytable` ADD `field2` DATE NULL AFTER `field1`;",
"ALTER TABLE mydb.mytable ADD `field2` DATE NULL AFTER `field1`;",
"ALTER TABLE mytable ADD `field2` DATE NULL AFTER `field1`;",
"ALTER TABLE mydb.mytable ADD field2 DATE NULL AFTER `field1`;",
}
table := []byte("mytable")
db := []byte("mydb")
for _, s := range cases {
m := expAlterTable.FindSubmatch([]byte(s))
mLen := len(m)
if m == nil || !bytes.Equal(m[mLen-1], table) || (len(m[mLen-2]) > 0 && !bytes.Equal(m[mLen-2], db)) {
t.Fatalf("TestAlterTableExp: case %s failed\n", s)
}
}
}
func TestRenameTableExp(t *testing.T) {
cases := []string{
"rename table `mydb`.`mytable` to `mydb`.`mytable1`",
"rename table `mytable` to `mytable1`",
"rename table mydb.mytable to mydb.mytable1",
"rename table mytable to mytable1",
"rename table `mydb`.`mytable` to `mydb`.`mytable2`, `mydb`.`mytable3` to `mydb`.`mytable1`",
"rename table `mytable` to `mytable2`, `mytable3` to `mytable1`",
"rename table mydb.mytable to mydb.mytable2, mydb.mytable3 to mydb.mytable1",
"rename table mytable to mytable2, mytable3 to mytable1",
}
table := []byte("mytable")
db := []byte("mydb")
for _, s := range cases {
m := expRenameTable.FindSubmatch([]byte(s))
mLen := len(m)
if m == nil || !bytes.Equal(m[mLen-1], table) || (len(m[mLen-2]) > 0 && !bytes.Equal(m[mLen-2], db)) {
t.Fatalf("TestRenameTableExp: case %s failed\n", s)
}
}
}
func TestDropTableExp(t *testing.T) {
cases := []string{
"drop table test1",
"DROP TABLE test1",
"DROP TABLE test1",
"DROP table IF EXISTS test.test1",
"drop table `test1`",
"DROP TABLE `test1`",
"DROP table IF EXISTS `test`.`test1`",
"DROP TABLE `test1` /* generated by server */",
"DROP table if exists test1",
"DROP table if exists `test1`",
"DROP table if exists test.test1",
"DROP table if exists `test`.test1",
"DROP table if exists `test`.`test1`",
"DROP table if exists test.`test1`",
"DROP table if exists test.`test1`",
}
table := []byte("test1")
for _, s := range cases {
m := expDropTable.FindSubmatch([]byte(s))
mLen := len(m)
if m == nil {
t.Fatalf("TestDropTableExp: case %s failed\n", s)
return
}
if mLen < 4 {
t.Fatalf("TestDropTableExp: case %s failed\n", s)
return
}
if !bytes.Equal(m[mLen-1], table) {
t.Fatalf("TestDropTableExp: case %s failed\n", s)
}
}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/juju/errors"
"github.com/siddontang/go-mysql/mysql"
)
type DumpConfig struct {
@ -23,8 +24,18 @@ type DumpConfig struct {
// Ignore table format is db.table
IgnoreTables []string `toml:"ignore_tables"`
// Dump only selected records. Quotes are mandatory
Where string `toml:"where"`
// If true, discard error msg, else, output to stderr
DiscardErr bool `toml:"discard_err"`
// Set true to skip --master-data if we have no privilege to do
// 'FLUSH TABLES WITH READ LOCK'
SkipMasterData bool `toml:"skip_master_data"`
// Set to change the default max_allowed_packet size
MaxAllowedPacketMB int `toml:"max_allowed_packet_mb"`
}
type Config struct {
@ -32,11 +43,30 @@ type Config struct {
User string `toml:"user"`
Password string `toml:"password"`
ServerID uint32 `toml:"server_id"`
Flavor string `toml:"flavor"`
DataDir string `toml:"data_dir"`
Charset string `toml:"charset"`
ServerID uint32 `toml:"server_id"`
Flavor string `toml:"flavor"`
HeartbeatPeriod time.Duration `toml:"heartbeat_period"`
ReadTimeout time.Duration `toml:"read_timeout"`
// IncludeTableRegex or ExcludeTableRegex should contain database name
// Only a table which matches IncludeTableRegex and dismatches ExcludeTableRegex will be processed
// eg, IncludeTableRegex : [".*\\.canal"], ExcludeTableRegex : ["mysql\\..*"]
// this will include all database's 'canal' table, except database 'mysql'
// Default IncludeTableRegex and ExcludeTableRegex are empty, this will include all tables
IncludeTableRegex []string `toml:"include_table_regex"`
ExcludeTableRegex []string `toml:"exclude_table_regex"`
// discard row event without table meta
DiscardNoMetaRowEvent bool `toml:"discard_no_meta_row_event"`
Dump DumpConfig `toml:"dump"`
UseDecimal bool `toml:"use_decimal"`
ParseTime bool `toml:"parse_time"`
// SemiSyncEnabled enables semi-sync or not.
SemiSyncEnabled bool `toml:"semi_sync_enabled"`
}
func NewConfigWithFile(name string) (*Config, error) {
@ -66,14 +96,14 @@ func NewDefaultConfig() *Config {
c.User = "root"
c.Password = ""
rand.Seed(time.Now().Unix())
c.ServerID = uint32(rand.Intn(1000)) + 1001
c.Charset = mysql.DEFAULT_CHARSET
c.ServerID = uint32(rand.New(rand.NewSource(time.Now().Unix())).Intn(1000)) + 1001
c.Flavor = "mysql"
c.DataDir = "./var"
c.Dump.ExecutionPath = "mysqldump"
c.Dump.DiscardErr = true
c.Dump.SkipMasterData = false
return c
}

View File

@ -1,12 +1,16 @@
package canal
import (
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/siddontang/go-mysql/dump"
"github.com/shopspring/decimal"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/schema"
)
@ -14,6 +18,7 @@ type dumpParseHandler struct {
c *Canal
name string
pos uint64
gset mysql.GTIDSet
}
func (h *dumpParseHandler) BinLog(name string, pos uint64) error {
@ -23,12 +28,18 @@ func (h *dumpParseHandler) BinLog(name string, pos uint64) error {
}
func (h *dumpParseHandler) Data(db string, table string, values []string) error {
if h.c.isClosed() {
return errCanalClosed
if err := h.c.ctx.Err(); err != nil {
return err
}
tableInfo, err := h.c.GetTable(db, table)
if err != nil {
e := errors.Cause(err)
if e == ErrExcludedTable ||
e == schema.ErrTableNotExist ||
e == schema.ErrMissingTableMeta {
return nil
}
log.Errorf("get %s.%s information err: %v", db, table, err)
return errors.Trace(err)
}
@ -38,32 +49,51 @@ func (h *dumpParseHandler) Data(db string, table string, values []string) error
for i, v := range values {
if v == "NULL" {
vs[i] = nil
} else if v == "_binary ''" {
vs[i] = []byte{}
} else if v[0] != '\'' {
if tableInfo.Columns[i].Type == schema.TYPE_NUMBER {
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
log.Errorf("parse row %v at %d error %v, skip", values, i, err)
return dump.ErrSkip
return fmt.Errorf("parse row %v at %d error %v, int expected", values, i, err)
}
vs[i] = n
} else if tableInfo.Columns[i].Type == schema.TYPE_FLOAT {
f, err := strconv.ParseFloat(v, 64)
if err != nil {
log.Errorf("parse row %v at %d error %v, skip", values, i, err)
return dump.ErrSkip
return fmt.Errorf("parse row %v at %d error %v, float expected", values, i, err)
}
vs[i] = f
} else if tableInfo.Columns[i].Type == schema.TYPE_DECIMAL {
if h.c.cfg.UseDecimal {
d, err := decimal.NewFromString(v)
if err != nil {
return fmt.Errorf("parse row %v at %d error %v, decimal expected", values, i, err)
}
vs[i] = d
} else {
f, err := strconv.ParseFloat(v, 64)
if err != nil {
return fmt.Errorf("parse row %v at %d error %v, float expected", values, i, err)
}
vs[i] = f
}
} else if strings.HasPrefix(v, "0x") {
buf, err := hex.DecodeString(v[2:])
if err != nil {
return fmt.Errorf("parse row %v at %d error %v, hex literal expected", values, i, err)
}
vs[i] = string(buf)
} else {
log.Errorf("parse row %v error, invalid type at %d, skip", values, i)
return dump.ErrSkip
return fmt.Errorf("parse row %v error, invalid type at %d", values, i)
}
} else {
vs[i] = v[1 : len(v)-1]
}
}
events := newRowsEvent(tableInfo, InsertAction, [][]interface{}{vs})
return h.c.travelRowsEventHandler(events)
events := newRowsEvent(tableInfo, InsertAction, [][]interface{}{vs}, nil)
return h.c.eventHandler.OnRow(events)
}
func (c *Canal) AddDumpDatabases(dbs ...string) {
@ -90,10 +120,64 @@ func (c *Canal) AddDumpIgnoreTables(db string, tables ...string) {
c.dumper.AddIgnoreTables(db, tables...)
}
func (c *Canal) dump() error {
if c.dumper == nil {
return errors.New("mysqldump does not exist")
}
c.master.UpdateTimestamp(uint32(time.Now().Unix()))
h := &dumpParseHandler{c: c}
// If users call StartFromGTID with empty position to start dumping with gtid,
// we record the current gtid position before dump starts.
//
// See tryDump() to see when dump is skipped.
if c.master.GTIDSet() != nil {
gset, err := c.GetMasterGTIDSet()
if err != nil {
return errors.Trace(err)
}
h.gset = gset
}
if c.cfg.Dump.SkipMasterData {
pos, err := c.GetMasterPos()
if err != nil {
return errors.Trace(err)
}
log.Infof("skip master data, get current binlog position %v", pos)
h.name = pos.Name
h.pos = uint64(pos.Pos)
}
start := time.Now()
log.Info("try dump MySQL and parse")
if err := c.dumper.DumpAndParse(h); err != nil {
return errors.Trace(err)
}
pos := mysql.Position{Name: h.name, Pos: uint32(h.pos)}
c.master.Update(pos)
if err := c.eventHandler.OnPosSynced(pos, true); err != nil {
return errors.Trace(err)
}
var startPos fmt.Stringer = pos
if h.gset != nil {
c.master.UpdateGTIDSet(h.gset)
startPos = h.gset
}
log.Infof("dump MySQL and parse OK, use %0.2f seconds, start binlog replication at %s",
time.Now().Sub(start).Seconds(), startPos)
return nil
}
func (c *Canal) tryDump() error {
if len(c.master.Name) > 0 && c.master.Position > 0 {
pos := c.master.Position()
gset := c.master.GTIDSet()
if (len(pos.Name) > 0 && pos.Pos > 0) ||
(gset != nil && gset.String() != "") {
// we will sync with binlog name and position
log.Infof("skip dump, use last binlog replication pos (%s, %d)", c.master.Name, c.master.Position)
log.Infof("skip dump, use last binlog replication pos %s or GTID set %s", pos, gset)
return nil
}
@ -102,19 +186,5 @@ func (c *Canal) tryDump() error {
return nil
}
h := &dumpParseHandler{c: c}
start := time.Now()
log.Info("try dump MySQL and parse")
if err := c.dumper.DumpAndParse(h); err != nil {
return errors.Trace(err)
}
log.Infof("dump MySQL and parse OK, use %0.2f seconds, start binlog replication at (%s, %d)",
time.Now().Sub(start).Seconds(), h.name, h.pos)
c.master.Update(h.name, uint32(h.pos))
c.master.Save(true)
return nil
return c.dump()
}

View File

@ -1,41 +1,41 @@
package canal
import (
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/replication"
)
var (
ErrHandleInterrupted = errors.New("do handler error, interrupted")
)
type RowsEventHandler interface {
// Handle RowsEvent, if return ErrHandleInterrupted, canal will
// stop the sync
Do(e *RowsEvent) error
type EventHandler interface {
OnRotate(roateEvent *replication.RotateEvent) error
// OnTableChanged is called when the table is created, altered, renamed or dropped.
// You need to clear the associated data like cache with the table.
// It will be called before OnDDL.
OnTableChanged(schema string, table string) error
OnDDL(nextPos mysql.Position, queryEvent *replication.QueryEvent) error
OnRow(e *RowsEvent) error
OnXID(nextPos mysql.Position) error
OnGTID(gtid mysql.GTIDSet) error
// OnPosSynced Use your own way to sync position. When force is true, sync position immediately.
OnPosSynced(pos mysql.Position, force bool) error
String() string
}
func (c *Canal) RegRowsEventHandler(h RowsEventHandler) {
c.rsLock.Lock()
c.rsHandlers = append(c.rsHandlers, h)
c.rsLock.Unlock()
type DummyEventHandler struct {
}
func (c *Canal) travelRowsEventHandler(e *RowsEvent) error {
c.rsLock.Lock()
defer c.rsLock.Unlock()
var err error
for _, h := range c.rsHandlers {
if err = h.Do(e); err != nil && !mysql.ErrorEqual(err, ErrHandleInterrupted) {
log.Errorf("handle %v err: %v", h, err)
} else if mysql.ErrorEqual(err, ErrHandleInterrupted) {
log.Errorf("handle %v err, interrupted", h)
return ErrHandleInterrupted
}
}
func (h *DummyEventHandler) OnRotate(*replication.RotateEvent) error { return nil }
func (h *DummyEventHandler) OnTableChanged(schema string, table string) error { return nil }
func (h *DummyEventHandler) OnDDL(nextPos mysql.Position, queryEvent *replication.QueryEvent) error {
return nil
}
func (h *DummyEventHandler) OnRow(*RowsEvent) error { return nil }
func (h *DummyEventHandler) OnXID(mysql.Position) error { return nil }
func (h *DummyEventHandler) OnGTID(mysql.GTIDSet) error { return nil }
func (h *DummyEventHandler) OnPosSynced(mysql.Position, bool) error { return nil }
func (h *DummyEventHandler) String() string { return "DummyEventHandler" }
// `SetEventHandler` registers the sync handler, you must register your
// own handler before starting Canal.
func (c *Canal) SetEventHandler(h EventHandler) {
c.eventHandler = h
}

View File

@ -1,89 +1,66 @@
package canal
import (
"bytes"
"os"
"sync"
"time"
"github.com/BurntSushi/toml"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go/ioutil2"
)
type masterInfo struct {
Addr string `toml:"addr"`
Name string `toml:"bin_name"`
Position uint32 `toml:"bin_pos"`
sync.RWMutex
name string
pos mysql.Position
l sync.Mutex
gset mysql.GTIDSet
lastSaveTime time.Time
timestamp uint32
}
func loadMasterInfo(name string) (*masterInfo, error) {
var m masterInfo
func (m *masterInfo) Update(pos mysql.Position) {
log.Debugf("update master position %s", pos)
m.name = name
f, err := os.Open(name)
if err != nil && !os.IsNotExist(errors.Cause(err)) {
return nil, errors.Trace(err)
} else if os.IsNotExist(errors.Cause(err)) {
return &m, nil
}
defer f.Close()
_, err = toml.DecodeReader(f, &m)
return &m, err
m.Lock()
m.pos = pos
m.Unlock()
}
func (m *masterInfo) Save(force bool) error {
m.l.Lock()
defer m.l.Unlock()
func (m *masterInfo) UpdateTimestamp(ts uint32) {
log.Debugf("update master timestamp %s", ts)
n := time.Now()
if !force && n.Sub(m.lastSaveTime) < time.Second {
m.Lock()
m.timestamp = ts
m.Unlock()
}
func (m *masterInfo) UpdateGTIDSet(gset mysql.GTIDSet) {
log.Debugf("update master gtid set %s", gset)
m.Lock()
m.gset = gset
m.Unlock()
}
func (m *masterInfo) Position() mysql.Position {
m.RLock()
defer m.RUnlock()
return m.pos
}
func (m *masterInfo) Timestamp() uint32 {
m.RLock()
defer m.RUnlock()
return m.timestamp
}
func (m *masterInfo) GTIDSet() mysql.GTIDSet {
m.RLock()
defer m.RUnlock()
if m.gset == nil {
return nil
}
var buf bytes.Buffer
e := toml.NewEncoder(&buf)
e.Encode(m)
var err error
if err = ioutil2.WriteFileAtomic(m.name, buf.Bytes(), 0644); err != nil {
log.Errorf("canal save master info to file %s err %v", m.name, err)
}
m.lastSaveTime = n
return errors.Trace(err)
}
func (m *masterInfo) Update(name string, pos uint32) {
m.l.Lock()
m.Name = name
m.Position = pos
m.l.Unlock()
}
func (m *masterInfo) Pos() mysql.Position {
var pos mysql.Position
m.l.Lock()
pos.Name = m.Name
pos.Pos = m.Position
m.l.Unlock()
return pos
}
func (m *masterInfo) Close() {
m.Save(true)
return m.gset.Clone()
}

View File

@ -3,16 +3,18 @@ package canal
import (
"fmt"
"github.com/juju/errors"
"github.com/siddontang/go-mysql/replication"
"github.com/siddontang/go-mysql/schema"
)
// The action name for sync.
const (
UpdateAction = "update"
InsertAction = "insert"
DeleteAction = "delete"
)
// RowsEvent is the event for row replication.
type RowsEvent struct {
Table *schema.Table
Action string
@ -22,35 +24,49 @@ type RowsEvent struct {
// Two rows for one event, format is [before update row, after update row]
// for update v0, only one row for a event, and we don't support this version.
Rows [][]interface{}
// Header can be used to inspect the event
Header *replication.EventHeader
}
func newRowsEvent(table *schema.Table, action string, rows [][]interface{}) *RowsEvent {
func newRowsEvent(table *schema.Table, action string, rows [][]interface{}, header *replication.EventHeader) *RowsEvent {
e := new(RowsEvent)
e.Table = table
e.Action = action
e.Rows = rows
e.Header = header
e.handleUnsigned()
return e
}
// Get primary keys in one row for a table, a table may use multi fields as the PK
func GetPKValues(table *schema.Table, row []interface{}) ([]interface{}, error) {
indexes := table.PKColumns
if len(indexes) == 0 {
return nil, errors.Errorf("table %s has no PK", table)
} else if len(table.Columns) != len(row) {
return nil, errors.Errorf("table %s has %d columns, but row data %v len is %d", table,
len(table.Columns), row, len(row))
func (r *RowsEvent) handleUnsigned() {
// Handle Unsigned Columns here, for binlog replication, we can't know the integer is unsigned or not,
// so we use int type but this may cause overflow outside sometimes, so we must convert to the really .
// unsigned type
if len(r.Table.UnsignedColumns) == 0 {
return
}
values := make([]interface{}, 0, len(indexes))
for _, index := range indexes {
values = append(values, row[index])
for i := 0; i < len(r.Rows); i++ {
for _, index := range r.Table.UnsignedColumns {
switch t := r.Rows[i][index].(type) {
case int8:
r.Rows[i][index] = uint8(t)
case int16:
r.Rows[i][index] = uint16(t)
case int32:
r.Rows[i][index] = uint32(t)
case int64:
r.Rows[i][index] = uint64(t)
case int:
r.Rows[i][index] = uint(t)
default:
// nothing to do
}
}
}
return values, nil
}
// String implements fmt.Stringer interface.

View File

@ -1,49 +1,68 @@
package canal
import (
"fmt"
"regexp"
"time"
"golang.org/x/net/context"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/satori/go.uuid"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/replication"
"github.com/siddontang/go-mysql/schema"
)
func (c *Canal) startSyncBinlog() error {
pos := mysql.Position{c.master.Name, c.master.Position}
var (
expCreateTable = regexp.MustCompile("(?i)^CREATE\\sTABLE(\\sIF\\sNOT\\sEXISTS)?\\s`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}\\s.*")
expAlterTable = regexp.MustCompile("(?i)^ALTER\\sTABLE\\s.*?`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}\\s.*")
expRenameTable = regexp.MustCompile("(?i)^RENAME\\sTABLE\\s.*?`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}\\s{1,}TO\\s.*?")
expDropTable = regexp.MustCompile("(?i)^DROP\\sTABLE(\\sIF\\sEXISTS){0,1}\\s`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}(?:$|\\s)")
expTruncateTable = regexp.MustCompile("(?i)^TRUNCATE\\s+(?:TABLE\\s+)?(?:`?([^`\\s]+)`?\\.`?)?([^`\\s]+)`?")
)
log.Infof("start sync binlog at %v", pos)
func (c *Canal) startSyncer() (*replication.BinlogStreamer, error) {
gset := c.master.GTIDSet()
if gset == nil {
pos := c.master.Position()
s, err := c.syncer.StartSync(pos)
if err != nil {
return nil, errors.Errorf("start sync replication at binlog %v error %v", pos, err)
}
log.Infof("start sync binlog at binlog file %v", pos)
return s, nil
} else {
s, err := c.syncer.StartSyncGTID(gset)
if err != nil {
return nil, errors.Errorf("start sync replication at GTID set %v error %v", gset, err)
}
log.Infof("start sync binlog at GTID set %v", gset)
return s, nil
}
}
s, err := c.syncer.StartSync(pos)
func (c *Canal) runSyncBinlog() error {
s, err := c.startSyncer()
if err != nil {
return errors.Errorf("start sync replication at %v error %v", pos, err)
return err
}
timeout := time.Second
forceSavePos := false
savePos := false
force := false
for {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ev, err := s.GetEvent(ctx)
cancel()
if err == context.DeadlineExceeded {
timeout = 2 * timeout
continue
}
ev, err := s.GetEvent(c.ctx)
if err != nil {
return errors.Trace(err)
}
savePos = false
force = false
pos := c.master.Position()
timeout = time.Second
curPos := pos.Pos
//next binlog pos
pos.Pos = ev.Header.LogPos
forceSavePos = false
// We only save position with RotateEvent and XIDEvent.
// For RowsEvent, we can't save the position until meeting XIDEvent
// which tells the whole transaction is over.
@ -52,24 +71,105 @@ func (c *Canal) startSyncBinlog() error {
case *replication.RotateEvent:
pos.Name = string(e.NextLogName)
pos.Pos = uint32(e.Position)
// r.ev <- pos
forceSavePos = true
log.Infof("rotate binlog to %v", pos)
log.Infof("rotate binlog to %s", pos)
savePos = true
force = true
if err = c.eventHandler.OnRotate(e); err != nil {
return errors.Trace(err)
}
case *replication.RowsEvent:
// we only focus row based event
if err = c.handleRowsEvent(ev); err != nil {
log.Errorf("handle rows event error %v", err)
return errors.Trace(err)
err = c.handleRowsEvent(ev)
if err != nil {
e := errors.Cause(err)
// if error is not ErrExcludedTable or ErrTableNotExist or ErrMissingTableMeta, stop canal
if e != ErrExcludedTable &&
e != schema.ErrTableNotExist &&
e != schema.ErrMissingTableMeta {
log.Errorf("handle rows event at (%s, %d) error %v", pos.Name, curPos, err)
return errors.Trace(err)
}
}
continue
case *replication.XIDEvent:
if e.GSet != nil {
c.master.UpdateGTIDSet(e.GSet)
}
savePos = true
// try to save the position later
if err := c.eventHandler.OnXID(pos); err != nil {
return errors.Trace(err)
}
case *replication.MariadbGTIDEvent:
// try to save the GTID later
gtid, err := mysql.ParseMariadbGTIDSet(e.GTID.String())
if err != nil {
return errors.Trace(err)
}
if err := c.eventHandler.OnGTID(gtid); err != nil {
return errors.Trace(err)
}
case *replication.GTIDEvent:
u, _ := uuid.FromBytes(e.SID)
gtid, err := mysql.ParseMysqlGTIDSet(fmt.Sprintf("%s:%d", u.String(), e.GNO))
if err != nil {
return errors.Trace(err)
}
if err := c.eventHandler.OnGTID(gtid); err != nil {
return errors.Trace(err)
}
case *replication.QueryEvent:
if e.GSet != nil {
c.master.UpdateGTIDSet(e.GSet)
}
var (
mb [][]byte
db []byte
table []byte
)
regexps := []regexp.Regexp{*expCreateTable, *expAlterTable, *expRenameTable, *expDropTable, *expTruncateTable}
for _, reg := range regexps {
mb = reg.FindSubmatch(e.Query)
if len(mb) != 0 {
break
}
}
mbLen := len(mb)
if mbLen == 0 {
continue
}
// the first last is table name, the second last is database name(if exists)
if len(mb[mbLen-2]) == 0 {
db = e.Schema
} else {
db = mb[mbLen-2]
}
table = mb[mbLen-1]
savePos = true
force = true
c.ClearTableCache(db, table)
log.Infof("table structure changed, clear table cache: %s.%s\n", db, table)
if err = c.eventHandler.OnTableChanged(string(db), string(table)); err != nil && errors.Cause(err) != schema.ErrTableNotExist {
return errors.Trace(err)
}
// Now we only handle Table Changed DDL, maybe we will support more later.
if err = c.eventHandler.OnDDL(pos, e); err != nil {
return errors.Trace(err)
}
default:
continue
}
c.master.Update(pos.Name, pos.Pos)
c.master.Save(forceSavePos)
if savePos {
c.master.Update(pos)
c.master.UpdateTimestamp(ev.Header.Timestamp)
if err := c.eventHandler.OnPosSynced(pos, force); err != nil {
return errors.Trace(err)
}
}
}
return nil
@ -84,7 +184,7 @@ func (c *Canal) handleRowsEvent(e *replication.BinlogEvent) error {
t, err := c.GetTable(schema, table)
if err != nil {
return errors.Trace(err)
return err
}
var action string
switch e.Header.EventType {
@ -97,25 +197,31 @@ func (c *Canal) handleRowsEvent(e *replication.BinlogEvent) error {
default:
return errors.Errorf("%s not supported now", e.Header.EventType)
}
events := newRowsEvent(t, action, ev.Rows)
return c.travelRowsEventHandler(events)
events := newRowsEvent(t, action, ev.Rows, e.Header)
return c.eventHandler.OnRow(events)
}
func (c *Canal) WaitUntilPos(pos mysql.Position, timeout int) error {
if timeout <= 0 {
timeout = 60
}
func (c *Canal) FlushBinlog() error {
_, err := c.Execute("FLUSH BINARY LOGS")
return errors.Trace(err)
}
timer := time.NewTimer(time.Duration(timeout) * time.Second)
func (c *Canal) WaitUntilPos(pos mysql.Position, timeout time.Duration) error {
timer := time.NewTimer(timeout)
for {
select {
case <-timer.C:
return errors.Errorf("wait position %v err", pos)
return errors.Errorf("wait position %v too long > %s", pos, timeout)
default:
curpos := c.master.Pos()
if curpos.Compare(pos) >= 0 {
err := c.FlushBinlog()
if err != nil {
return errors.Trace(err)
}
curPos := c.master.Position()
if curPos.Compare(pos) >= 0 {
return nil
} else {
log.Debugf("master pos is %v, wait catching %v", curPos, pos)
time.Sleep(100 * time.Millisecond)
}
}
@ -124,14 +230,46 @@ func (c *Canal) WaitUntilPos(pos mysql.Position, timeout int) error {
return nil
}
func (c *Canal) CatchMasterPos(timeout int) error {
func (c *Canal) GetMasterPos() (mysql.Position, error) {
rr, err := c.Execute("SHOW MASTER STATUS")
if err != nil {
return errors.Trace(err)
return mysql.Position{}, errors.Trace(err)
}
name, _ := rr.GetString(0, 0)
pos, _ := rr.GetInt(0, 1)
return c.WaitUntilPos(mysql.Position{name, uint32(pos)}, timeout)
return mysql.Position{Name: name, Pos: uint32(pos)}, nil
}
func (c *Canal) GetMasterGTIDSet() (mysql.GTIDSet, error) {
query := ""
switch c.cfg.Flavor {
case mysql.MariaDBFlavor:
query = "SELECT @@GLOBAL.gtid_current_pos"
default:
query = "SELECT @@GLOBAL.GTID_EXECUTED"
}
rr, err := c.Execute(query)
if err != nil {
return nil, errors.Trace(err)
}
gx, err := rr.GetString(0, 0)
if err != nil {
return nil, errors.Trace(err)
}
gset, err := mysql.ParseGTIDSet(c.cfg.Flavor, gx)
if err != nil {
return nil, errors.Trace(err)
}
return gset, nil
}
func (c *Canal) CatchMasterPos(timeout time.Duration) error {
pos, err := c.GetMasterPos()
if err != nil {
return errors.Trace(err)
}
return c.WaitUntilPos(pos, timeout)
}

View File

@ -4,12 +4,29 @@ import (
"bytes"
"crypto/tls"
"encoding/binary"
"fmt"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/packet"
)
const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
// defines the supported auth plugins
var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD}
// helper function to determine what auth methods are allowed by this client
func authPluginAllowed(pluginName string) bool {
for _, p := range supportedAuthPlugins {
if pluginName == p {
return true
}
}
return false
}
// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (c *Conn) readInitialHandshake() error {
data, err := c.ReadPacket()
if err != nil {
@ -24,39 +41,44 @@ func (c *Conn) readInitialHandshake() error {
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
}
//skip mysql version
//mysql version end with 0x00
// skip mysql version
// mysql version end with 0x00
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
//connection id length is 4
// connection id length is 4
c.connectionID = uint32(binary.LittleEndian.Uint32(data[pos : pos+4]))
pos += 4
c.salt = []byte{}
c.salt = append(c.salt, data[pos:pos+8]...)
//skip filter
// skip filter
pos += 8 + 1
//capability lower 2 bytes
// capability lower 2 bytes
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
// check protocol
if c.capability&CLIENT_PROTOCOL_41 == 0 {
return errors.New("the MySQL server can not support protocol 41 and above required by the client")
}
if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil {
return errors.New("the MySQL Server does not support TLS required by the client")
}
pos += 2
if len(data) > pos {
//skip server charset
// skip server charset
//c.charset = data[pos]
pos += 1
c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2
// capability flags (upper 2 bytes)
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
pos += 2
//skip auth data len or [00]
//skip reserved (all [00])
// skip auth data len or [00]
// skip reserved (all [00])
pos += 10 + 1
// The documentation is ambiguous about the length.
@ -64,78 +86,131 @@ func (c *Conn) readInitialHandshake() error {
// mysql-proxy also use 12
// which is not documented but seems to work.
c.salt = append(c.salt, data[pos:pos+12]...)
pos += 13
// auth plugin
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
c.authPluginName = string(data[pos : pos+end])
} else {
c.authPluginName = string(data[pos:])
}
}
// if server gives no default auth plugin name, use a client default
if c.authPluginName == "" {
c.authPluginName = defaultAuthPluginName
}
return nil
}
// generate auth response data according to auth plugin
//
// NOTE: the returned boolean value indicates whether to add a \NUL to the end of data.
// it is quite tricky because MySQl server expects different formats of responses in different auth situations.
// here the \NUL needs to be added when sending back the empty password or cleartext password in 'sha256_password'
// authentication.
func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) {
// password hashing
switch c.authPluginName {
case AUTH_NATIVE_PASSWORD:
return CalcPassword(authData[:20], []byte(c.password)), false, nil
case AUTH_CACHING_SHA2_PASSWORD:
return CalcCachingSha2Password(authData, c.password), false, nil
case AUTH_SHA256_PASSWORD:
if len(c.password) == 0 {
return nil, true, nil
}
if c.tlsConfig != nil || c.proto == "unix" {
// write cleartext auth packet
// see: https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html
return []byte(c.password), true, nil
} else {
// request public key from server
// see: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html
return []byte{1}, false, nil
}
default:
// not reachable
return nil, false, fmt.Errorf("auth plugin '%s' is not supported", c.authPluginName)
}
}
// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (c *Conn) writeAuthHandshake() error {
if !authPluginAllowed(c.authPluginName) {
return fmt.Errorf("unknow auth plugin name '%s'", c.authPluginName)
}
// Adjust client capability flags based on server support
capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH | c.capability&CLIENT_LONG_FLAG
// To enable TLS / SSL
if c.TLSConfig != nil {
capability |= CLIENT_PLUGIN_AUTH
if c.tlsConfig != nil {
capability |= CLIENT_SSL
}
capability &= c.capability
auth, addNull, err := c.genAuthResponse(c.salt)
if err != nil {
return err
}
// encode length of the auth plugin data
// here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte
// see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer
var authRespLEIBuf [9]byte
authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
if len(authRespLEI) > 1 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
}
//packet length
//capbility 4
//capability 4
//max-packet size 4
//charset 1
//reserved all[0] 23
length := 4 + 4 + 1 + 23
//username
length += len(c.user) + 1
//we only support secure connection
auth := CalcPassword(c.salt, []byte(c.password))
length += 1 + len(auth)
//auth
//mysql_native_password + null-terminated
length := 4 + 4 + 1 + 23 + len(c.user) + 1 + len(authRespLEI) + len(auth) + 21 + 1
if addNull {
length++
}
// db name
if len(c.db) > 0 {
capability |= CLIENT_CONNECT_WITH_DB
length += len(c.db) + 1
}
// mysql_native_password + null-terminated
length += 21 + 1
c.capability = capability
data := make([]byte, length+4)
//capability [32 bit]
// capability [32 bit]
data[4] = byte(capability)
data[5] = byte(capability >> 8)
data[6] = byte(capability >> 16)
data[7] = byte(capability >> 24)
//MaxPacketSize [32 bit] (none)
//data[8] = 0x00
//data[9] = 0x00
//data[10] = 0x00
//data[11] = 0x00
// MaxPacketSize [32 bit] (none)
data[8] = 0x00
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00
//Charset [1 byte]
//use default collation id 33 here, is utf-8
// Charset [1 byte]
// use default collation id 33 here, is utf-8
data[12] = byte(DEFAULT_COLLATION_ID)
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if c.TLSConfig != nil {
if c.tlsConfig != nil {
// Send TLS / SSL request packet
if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil {
return err
}
// Switch to TLS
tlsConn := tls.Client(c.Conn.Conn, c.TLSConfig)
tlsConn := tls.Client(c.Conn.Conn, c.tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return err
}
@ -145,10 +220,13 @@ func (c *Conn) writeAuthHandshake() error {
c.Sequence = currentSequence
}
//Filler [23 bytes] (all 0x00)
pos := 13 + 23
// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
data[pos] = 0
}
//User [null terminated string]
// User [null terminated string]
if len(c.user) > 0 {
pos += copy(data[pos:], c.user)
}
@ -156,8 +234,12 @@ func (c *Conn) writeAuthHandshake() error {
pos++
// auth [length encoded integer]
data[pos] = byte(len(auth))
pos += 1 + copy(data[pos+1:], auth)
pos += copy(data[pos:], authRespLEI)
pos += copy(data[pos:], auth)
if addNull {
data[pos] = 0x00
pos++
}
// db [null terminated string]
if len(c.db) > 0 {
@ -167,7 +249,7 @@ func (c *Conn) writeAuthHandshake() error {
}
// Assume native client during response
pos += copy(data[pos:], "mysql_native_password")
pos += copy(data[pos:], c.authPluginName)
data[pos] = 0x00
return c.WritePacket(data)

View File

@ -1,41 +1,56 @@
package client
import (
"crypto/tls"
"flag"
"fmt"
"strings"
"testing"
"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/siddontang/go-mysql/test_util/test_keys"
"github.com/siddontang/go-mysql/mysql"
)
var testHost = flag.String("host", "127.0.0.1", "MySQL server host")
var testPort = flag.Int("port", 3306, "MySQL server port")
// We cover the whole range of MySQL server versions using docker-compose to bind them to different ports for testing.
// MySQL is constantly updating auth plugin to make it secure:
// starting from MySQL 8.0.4, a new auth plugin is introduced, causing plain password auth to fail with error:
// ERROR 1251 (08004): Client does not support authentication protocol requested by server; consider upgrading MySQL client
// Hint: use docker-compose to start corresponding MySQL docker containers and add the their ports here
var testPort = flag.String("port", "3306", "MySQL server port") // choose one or more form 5561,5641,3306,5722,8003,8012,8013, e.g. '3306,5722,8003'
var testUser = flag.String("user", "root", "MySQL user")
var testPassword = flag.String("pass", "", "MySQL password")
var testDB = flag.String("db", "test", "MySQL test database")
func Test(t *testing.T) {
segs := strings.Split(*testPort, ",")
for _, seg := range segs {
Suite(&clientTestSuite{port: seg})
}
TestingT(t)
}
type clientTestSuite struct {
c *Conn
c *Conn
port string
}
var _ = Suite(&clientTestSuite{})
func (s *clientTestSuite) SetUpSuite(c *C) {
var err error
addr := fmt.Sprintf("%s:%d", *testHost, *testPort)
s.c, err = Connect(addr, *testUser, *testPassword, *testDB)
addr := fmt.Sprintf("%s:%s", *testHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "")
if err != nil {
c.Fatal(err)
}
_, err = s.c.Execute("CREATE DATABASE IF NOT EXISTS " + *testDB)
c.Assert(err, IsNil)
_, err = s.c.Execute("USE " + *testDB)
c.Assert(err, IsNil)
s.testConn_CreateTable(c)
s.testStmt_CreateTable(c)
}
@ -78,12 +93,15 @@ func (s *clientTestSuite) TestConn_Ping(c *C) {
c.Assert(err, IsNil)
}
func (s *clientTestSuite) TestConn_TLS(c *C) {
// NOTE for MySQL 5.5 and 5.6, server side has to config SSL to pass the TLS test, otherwise, it will throw error that
// MySQL server does not support TLS required by the client. However, for MySQL 5.7 and above, auto generated certificates
// are used by default so that manual config is no longer necessary.
func (s *clientTestSuite) TestConn_TLS_Verify(c *C) {
// Verify that the provided tls.Config is used when attempting to connect to mysql.
// An empty tls.Config will result in a connection error.
addr := fmt.Sprintf("%s:%d", *testHost, *testPort)
addr := fmt.Sprintf("%s:%s", *testHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
c.TLSConfig = &tls.Config{}
c.UseSSL(false)
})
if err == nil {
c.Fatal("expected error")
@ -91,7 +109,34 @@ func (s *clientTestSuite) TestConn_TLS(c *C) {
expected := "either ServerName or InsecureSkipVerify must be specified in the tls.Config"
if !strings.Contains(err.Error(), expected) {
c.Fatal("expected '%s' to contain '%s'", err.Error(), expected)
c.Fatalf("expected '%s' to contain '%s'", err.Error(), expected)
}
}
func (s *clientTestSuite) TestConn_TLS_Skip_Verify(c *C) {
// An empty tls.Config will result in a connection error but we can configure to skip it.
addr := fmt.Sprintf("%s:%s", *testHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
c.UseSSL(true)
})
c.Assert(err, Equals, nil)
}
func (s *clientTestSuite) TestConn_TLS_Certificate(c *C) {
// This test uses the TLS suite in 'go-mysql/docker/resources'. The certificates are not valid for any names.
// And if server uses auto-generated certificates, it will be an error like:
// "x509: certificate is valid for MySQL_Server_8.0.12_Auto_Generated_Server_Certificate, not not-a-valid-name"
tlsConfig := NewClientTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, false, "not-a-valid-name")
addr := fmt.Sprintf("%s:%s", *testHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
c.SetTLSConfig(tlsConfig)
})
if err == nil {
c.Fatal("expected error")
}
if !strings.Contains(errors.Details(err), "certificate is not valid for any names") &&
!strings.Contains(errors.Details(err), "certificate is valid for") {
c.Fatalf("expected errors for server name verification, but got unknown error: %s", errors.Details(err))
}
}
@ -349,4 +394,4 @@ func (s *clientTestSuite) TestStmt_Trans(c *C) {
str, _ = r.GetString(0, 0)
c.Assert(str, Equals, `abc`)
}
}

View File

@ -18,7 +18,8 @@ type Conn struct {
user string
password string
db string
TLSConfig *tls.Config
tlsConfig *tls.Config
proto string
capability uint32
@ -26,7 +27,8 @@ type Conn struct {
charset string
salt []byte
salt []byte
authPluginName string
connectionID uint32
}
@ -56,6 +58,7 @@ func Connect(addr string, user string, password string, dbName string, options .
c.user = user
c.password = password
c.db = dbName
c.proto = proto
//use default charset here, utf-8
c.charset = DEFAULT_CHARSET
@ -85,7 +88,7 @@ func (c *Conn) handshake() error {
return errors.Trace(err)
}
if _, err := c.readOK(); err != nil {
if err := c.handleAuthResult(); err != nil {
c.Close()
return errors.Trace(err)
}
@ -109,6 +112,18 @@ func (c *Conn) Ping() error {
return nil
}
// use default SSL
// pass to options when connect
func (c *Conn) UseSSL(insecureSkipVerify bool) {
c.tlsConfig = &tls.Config{InsecureSkipVerify: insecureSkipVerify}
}
// use user-specified TLS config
// pass to options when connect
func (c *Conn) SetTLSConfig(config *tls.Config) {
c.tlsConfig = config
}
func (c *Conn) UseDB(dbName string) error {
if c.db == dbName {
return nil

View File

@ -1,8 +1,14 @@
package client
import "C"
import (
"encoding/binary"
"bytes"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go/hack"
@ -32,7 +38,7 @@ func (c *Conn) isEOFPacket(data []byte) bool {
func (c *Conn) handleOKPacket(data []byte) (*Result, error) {
var n int
var pos int = 1
var pos = 1
r := new(Result)
@ -64,7 +70,7 @@ func (c *Conn) handleOKPacket(data []byte) (*Result, error) {
func (c *Conn) handleErrorPacket(data []byte) error {
e := new(MyError)
var pos int = 1
var pos = 1
e.Code = binary.LittleEndian.Uint16(data[pos:])
pos += 2
@ -81,6 +87,116 @@ func (c *Conn) handleErrorPacket(data []byte) error {
return e
}
func (c *Conn) handleAuthResult() error {
data, switchToPlugin, err := c.readAuthResult()
if err != nil {
return err
}
// handle auth switch, only support 'sha256_password', and 'caching_sha2_password'
if switchToPlugin != "" {
//fmt.Printf("now switching auth plugin to '%s'\n", switchToPlugin)
if data == nil {
data = c.salt
} else {
copy(c.salt, data)
}
c.authPluginName = switchToPlugin
auth, addNull, err := c.genAuthResponse(data)
if err = c.WriteAuthSwitchPacket(auth, addNull); err != nil {
return err
}
// Read Result Packet
data, switchToPlugin, err = c.readAuthResult()
if err != nil {
return err
}
// Do not allow to change the auth plugin more than once
if switchToPlugin != "" {
return errors.Errorf("can not switch auth plugin more than once")
}
}
// handle caching_sha2_password
if c.authPluginName == AUTH_CACHING_SHA2_PASSWORD {
if data == nil {
return nil // auth already succeeded
}
if data[0] == CACHE_SHA2_FAST_AUTH {
if _, err = c.readOK(); err == nil {
return nil // auth successful
}
} else if data[0] == CACHE_SHA2_FULL_AUTH {
// need full authentication
if c.tlsConfig != nil || c.proto == "unix" {
if err = c.WriteClearAuthPacket(c.password); err != nil {
return err
}
} else {
if err = c.WritePublicKeyAuthPacket(c.password, c.salt); err != nil {
return err
}
}
} else {
errors.Errorf("invalid packet")
}
} else if c.authPluginName == AUTH_SHA256_PASSWORD {
if len(data) == 0 {
return nil // auth already succeeded
}
block, _ := pem.Decode(data)
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
}
// send encrypted password
err = c.WriteEncryptedPassword(c.password, c.salt, pub.(*rsa.PublicKey))
if err != nil {
return err
}
_, err = c.readOK()
return err
}
return nil
}
func (c *Conn) readAuthResult() ([]byte, string, error) {
data, err := c.ReadPacket()
if err != nil {
return nil, "", err
}
// see: https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
// packet indicator
switch data[0] {
case OK_HEADER:
_, err := c.handleOKPacket(data)
return nil, "", err
case MORE_DATE_HEADER:
return data[1:], "", err
case EOF_HEADER:
// server wants to switch auth
if len(data) < 1 {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
return nil, AUTH_MYSQL_OLD_PASSWORD, nil
}
pluginEndIndex := bytes.IndexByte(data, 0x00)
if pluginEndIndex < 0 {
return nil, "", errors.New("invalid packet")
}
plugin := string(data[1:pluginEndIndex])
authData := data[pluginEndIndex+1:]
return authData, plugin, nil
default: // Error otherwise
return nil, "", c.handleErrorPacket(data)
}
}
func (c *Conn) readOK() (*Result, error) {
data, err := c.ReadPacket()
if err != nil {

View File

@ -23,6 +23,6 @@ func main() {
err := p.ParseFile(*name, *offset, f)
if err != nil {
println(err)
println(err.Error())
}
}

View File

@ -7,8 +7,10 @@ import (
"os/signal"
"strings"
"syscall"
"time"
"github.com/siddontang/go-mysql/canal"
"github.com/siddontang/go-mysql/mysql"
)
var host = flag.String("host", "127.0.0.1", "MySQL host")
@ -18,8 +20,6 @@ var password = flag.String("password", "", "MySQL password")
var flavor = flag.String("flavor", "mysql", "Flavor: mysql or mariadb")
var dataDir = flag.String("data-dir", "./var", "Path to store data, like master.info")
var serverID = flag.Int("server-id", 101, "Unique Server ID")
var mysqldump = flag.String("mysqldump", "mysqldump", "mysqldump execution path")
@ -28,6 +28,12 @@ var tables = flag.String("tables", "", "dump tables, seperated by comma, will ov
var tableDB = flag.String("table_db", "test", "database for dump tables")
var ignoreTables = flag.String("ignore_tables", "", "ignore tables, must be database.table format, separated by comma")
var startName = flag.String("bin_name", "", "start sync from binlog name")
var startPos = flag.Uint("bin_pos", 0, "start sync from binlog position of")
var heartbeatPeriod = flag.Duration("heartbeat", 60*time.Second, "master heartbeat period")
var readTimeout = flag.Duration("read_timeout", 90*time.Second, "connection read timeout")
func main() {
flag.Parse()
@ -36,8 +42,10 @@ func main() {
cfg.User = *user
cfg.Password = *password
cfg.Flavor = *flavor
cfg.DataDir = *dataDir
cfg.UseDecimal = true
cfg.ReadTimeout = *readTimeout
cfg.HeartbeatPeriod = *heartbeatPeriod
cfg.ServerID = uint32(*serverID)
cfg.Dump.ExecutionPath = *mysqldump
cfg.Dump.DiscardErr = false
@ -65,14 +73,20 @@ func main() {
c.AddDumpDatabases(subs...)
}
c.RegRowsEventHandler(&handler{})
c.SetEventHandler(&handler{})
err = c.Start()
if err != nil {
fmt.Printf("start canal err %V", err)
os.Exit(1)
startPos := mysql.Position{
Name: *startName,
Pos: uint32(*startPos),
}
go func() {
err = c.RunFrom(startPos)
if err != nil {
fmt.Printf("start canal err %v", err)
}
}()
sc := make(chan os.Signal, 1)
signal.Notify(sc,
os.Kill,
@ -88,9 +102,10 @@ func main() {
}
type handler struct {
canal.DummyEventHandler
}
func (h *handler) Do(e *canal.RowsEvent) error {
func (h *handler) OnRow(e *canal.RowsEvent) error {
fmt.Printf("%v\n", e)
return nil

View File

@ -4,12 +4,11 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"golang.org/x/net/context"
"github.com/juju/errors"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/replication"
@ -43,11 +42,12 @@ func main() {
Password: *password,
RawModeEnabled: *rawMode,
SemiSyncEnabled: *semiSync,
UseDecimal: true,
}
b := replication.NewBinlogSyncer(cfg)
pos := mysql.Position{*file, uint32(*pos)}
pos := mysql.Position{Name: *file, Pos: uint32(*pos)}
if len(*backupPath) > 0 {
// Backup will always use RawMode.
err := b.StartBackup(*backupPath, pos, 0)
@ -65,6 +65,11 @@ func main() {
for {
e, err := s.GetEvent(context.Background())
if err != nil {
// Try to output all left events
events := s.DumpEvents()
for _, e := range events {
e.Dump(os.Stdout)
}
fmt.Printf("Get event error: %v\n", errors.ErrorStack(err))
return
}

View File

@ -11,6 +11,11 @@ import (
// Use docker mysql to test, mysql is 3306
var testHost = flag.String("host", "127.0.0.1", "MySQL master host")
// possible choices for different MySQL versions are: 5561,5641,3306,5722,8003,8012
var testPort = flag.Int("port", 3306, "MySQL server port")
var testUser = flag.String("user", "root", "MySQL user")
var testPassword = flag.String("pass", "", "MySQL password")
var testDB = flag.String("db", "test", "MySQL test database")
func TestDriver(t *testing.T) {
TestingT(t)
@ -23,7 +28,8 @@ type testDriverSuite struct {
var _ = Suite(&testDriverSuite{})
func (s *testDriverSuite) SetUpSuite(c *C) {
dsn := fmt.Sprintf("root@%s:3306?test", *testHost)
addr := fmt.Sprintf("%s:%d", *testHost, *testPort)
dsn := fmt.Sprintf("%s:%s@%s?%s", *testUser, *testPassword, addr, *testDB)
var err error
s.db, err = sqlx.Open("mysql", dsn)

View File

@ -20,7 +20,8 @@ type driver struct {
// DSN user:password@addr[?db]
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
seps := strings.Split(dsn, "@")
lastIndex := strings.LastIndex(dsn, "@")
seps := []string{dsn[:lastIndex], dsn[lastIndex+1:]}
if len(seps) != 2 {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
}

View File

@ -8,6 +8,8 @@ import (
"strings"
"github.com/juju/errors"
"github.com/siddontang/go-log/log"
. "github.com/siddontang/go-mysql/mysql"
)
// Unlick mysqldump, Dumper is designed for parsing and syning data easily.
@ -25,9 +27,16 @@ type Dumper struct {
Databases []string
Where string
Charset string
IgnoreTables map[string][]string
ErrOut io.Writer
masterDataSkipped bool
maxAllowedPacket int
hexBlob bool
}
func NewDumper(executionPath string, addr string, user string, password string) (*Dumper, error) {
@ -47,17 +56,40 @@ func NewDumper(executionPath string, addr string, user string, password string)
d.Password = password
d.Tables = make([]string, 0, 16)
d.Databases = make([]string, 0, 16)
d.Charset = DEFAULT_CHARSET
d.IgnoreTables = make(map[string][]string)
d.masterDataSkipped = false
d.ErrOut = os.Stderr
return d, nil
}
func (d *Dumper) SetCharset(charset string) {
d.Charset = charset
}
func (d *Dumper) SetWhere(where string) {
d.Where = where
}
func (d *Dumper) SetErrOut(o io.Writer) {
d.ErrOut = o
}
// In some cloud MySQL, we have no privilege to use `--master-data`.
func (d *Dumper) SkipMasterData(v bool) {
d.masterDataSkipped = v
}
func (d *Dumper) SetMaxAllowedPacket(i int) {
d.maxAllowedPacket = i
}
func (d *Dumper) SetHexBlob(v bool) {
d.hexBlob = v
}
func (d *Dumper) AddDatabases(dbs ...string) {
d.Databases = append(d.Databases, dbs...)
}
@ -82,22 +114,35 @@ func (d *Dumper) Reset() {
d.TableDB = ""
d.IgnoreTables = make(map[string][]string)
d.Databases = d.Databases[0:0]
d.Where = ""
}
func (d *Dumper) Dump(w io.Writer) error {
args := make([]string, 0, 16)
// Common args
seps := strings.Split(d.Addr, ":")
args = append(args, fmt.Sprintf("--host=%s", seps[0]))
if len(seps) > 1 {
args = append(args, fmt.Sprintf("--port=%s", seps[1]))
if strings.Contains(d.Addr, "/") {
args = append(args, fmt.Sprintf("--socket=%s", d.Addr))
} else {
seps := strings.SplitN(d.Addr, ":", 2)
args = append(args, fmt.Sprintf("--host=%s", seps[0]))
if len(seps) > 1 {
args = append(args, fmt.Sprintf("--port=%s", seps[1]))
}
}
args = append(args, fmt.Sprintf("--user=%s", d.User))
args = append(args, fmt.Sprintf("--password=%s", d.Password))
args = append(args, "--master-data")
if !d.masterDataSkipped {
args = append(args, "--master-data")
}
if d.maxAllowedPacket > 0 {
// mysqldump param should be --max-allowed-packet=%dM not be --max_allowed_packet=%dM
args = append(args, fmt.Sprintf("--max-allowed-packet=%dM", d.maxAllowedPacket))
}
args = append(args, "--single-transaction")
args = append(args, "--skip-lock-tables")
@ -112,12 +157,25 @@ func (d *Dumper) Dump(w io.Writer) error {
// Multi row is easy for us to parse the data
args = append(args, "--skip-extended-insert")
if d.hexBlob {
// Use hex for the binary type
args = append(args, "--hex-blob")
}
for db, tables := range d.IgnoreTables {
for _, table := range tables {
args = append(args, fmt.Sprintf("--ignore-table=%s.%s", db, table))
}
}
if len(d.Charset) != 0 {
args = append(args, fmt.Sprintf("--default-character-set=%s", d.Charset))
}
if len(d.Where) != 0 {
args = append(args, fmt.Sprintf("--where=%s", d.Where))
}
if len(d.Tables) == 0 && len(d.Databases) == 0 {
args = append(args, "--all-databases")
} else if len(d.Tables) == 0 {
@ -133,6 +191,7 @@ func (d *Dumper) Dump(w io.Writer) error {
w.Write([]byte(fmt.Sprintf("USE `%s`;\n", d.TableDB)))
}
log.Infof("exec mysqldump with %v", args)
cmd := exec.Command(d.ExecutionPath, args...)
cmd.Stderr = d.ErrOut
@ -147,7 +206,7 @@ func (d *Dumper) DumpAndParse(h ParseHandler) error {
done := make(chan error, 1)
go func() {
err := Parse(r, h)
err := Parse(r, h, !d.masterDataSkipped)
r.CloseWithError(err)
done <- err
}()

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io/ioutil"
"os"
"strings"
"testing"
. "github.com/pingcap/check"
@ -38,6 +39,7 @@ func (s *schemaTestSuite) SetUpSuite(c *C) {
c.Assert(err, IsNil)
c.Assert(s.d, NotNil)
s.d.SetCharset("utf8")
s.d.SetErrOut(os.Stderr)
_, err = s.conn.Execute("CREATE DATABASE IF NOT EXISTS test1")
@ -177,7 +179,7 @@ func (s *schemaTestSuite) TestParse(c *C) {
err := s.d.Dump(&buf)
c.Assert(err, IsNil)
err = Parse(&buf, new(testParseHandler))
err = Parse(&buf, new(testParseHandler), true)
c.Assert(err, IsNil)
}
@ -196,3 +198,29 @@ func (s *parserTestSuite) TestParseValue(c *C) {
values, err = parseValues(str)
c.Assert(err, NotNil)
}
func (s *parserTestSuite) TestParseLine(c *C) {
lines := []struct {
line string
expected string
}{
{line: "INSERT INTO `test` VALUES (1, 'first', 'hello mysql; 2', 'e1', 'a,b');",
expected: "1, 'first', 'hello mysql; 2', 'e1', 'a,b'"},
{line: "INSERT INTO `test` VALUES (0x22270073646661736661736466, 'first', 'hello mysql; 2', 'e1', 'a,b');",
expected: "0x22270073646661736661736466, 'first', 'hello mysql; 2', 'e1', 'a,b'"},
}
f := func(c rune) bool {
return c == '\r' || c == '\n'
}
for _, t := range lines {
l := strings.TrimRightFunc(t.line, f)
m := valuesExp.FindAllStringSubmatch(l, -1)
c.Assert(m, HasLen, 1)
c.Assert(m[0][1], Matches, "test")
c.Assert(m[0][2], Matches, t.expected)
}
}

View File

@ -6,6 +6,7 @@ import (
"io"
"regexp"
"strconv"
"strings"
"github.com/juju/errors"
"github.com/siddontang/go-mysql/mysql"
@ -34,7 +35,7 @@ func init() {
// Parse the dump data with Dumper generate.
// It can not parse all the data formats with mysqldump outputs
func Parse(r io.Reader, h ParseHandler) error {
func Parse(r io.Reader, h ParseHandler, parseBinlogPos bool) error {
rb := bufio.NewReaderSize(r, 1024*16)
var db string
@ -48,9 +49,12 @@ func Parse(r io.Reader, h ParseHandler) error {
break
}
line = line[0 : len(line)-1]
// Ignore '\n' on Linux or '\r\n' on Windows
line = strings.TrimRightFunc(line, func(c rune) bool {
return c == '\r' || c == '\n'
})
if !binlogParsed {
if parseBinlogPos && !binlogParsed {
if m := binlogExp.FindAllStringSubmatch(line, -1); len(m) == 1 {
name := m[0][1]
pos, err := strconv.ParseUint(m[0][2], 10, 64)

View File

@ -46,12 +46,12 @@ func (h *MariadbGTIDHandler) FindBestSlaves(slaves []*Server) ([]*Server, error)
if len(str) == 0 {
seq = 0
} else {
g, err := ParseMariadbGTIDSet(str)
g, err := ParseMariadbGTID(str)
if err != nil {
return nil, errors.Trace(err)
}
seq = g.(MariadbGTID).SequenceNumber
seq = g.SequenceNumber
}
ps[i] = seq
@ -118,7 +118,7 @@ func (h *MariadbGTIDHandler) WaitRelayLogDone(s *Server) error {
fname, _ := r.GetStringByName(0, "Master_Log_File")
pos, _ := r.GetIntByName(0, "Read_Master_Log_Pos")
return s.MasterPosWait(Position{fname, uint32(pos)}, 0)
return s.MasterPosWait(Position{Name: fname, Pos: uint32(pos)}, 0)
}
func (h *MariadbGTIDHandler) WaitCatchMaster(s *Server, m *Server) error {

View File

@ -152,7 +152,7 @@ func (s *Server) FetchSlaveReadPos() (Position, error) {
fname, _ := r.GetStringByName(0, "Master_Log_File")
pos, _ := r.GetIntByName(0, "Read_Master_Log_Pos")
return Position{fname, uint32(pos)}, nil
return Position{Name: fname, Pos: uint32(pos)}, nil
}
// Get current executed binlog filename and position from master
@ -165,7 +165,7 @@ func (s *Server) FetchSlaveExecutePos() (Position, error) {
fname, _ := r.GetStringByName(0, "Relay_Master_Log_File")
pos, _ := r.GetIntByName(0, "Exec_Master_Log_Pos")
return Position{fname, uint32(pos)}, nil
return Position{Name: fname, Pos: uint32(pos)}, nil
}
func (s *Server) MasterPosWait(pos Position, timeout int) error {

View File

@ -1,30 +0,0 @@
hash: 1a3d05afef96cd7601a004e573b128db9051eecd1b5d0a3d69d3fa1ee1a3e3b8
updated: 2016-09-03T12:30:00.028685232+08:00
imports:
- name: github.com/BurntSushi/toml
version: 056c9bc7be7190eaa7715723883caffa5f8fa3e4
- name: github.com/go-sql-driver/mysql
version: 3654d25ec346ee8ce71a68431025458d52a38ac0
- name: github.com/jmoiron/sqlx
version: 54aec3fd91a2b2129ffaca0d652b8a9223ee2d9e
subpackages:
- reflectx
- name: github.com/juju/errors
version: 6f54ff6318409d31ff16261533ce2c8381a4fd5d
- name: github.com/ngaut/log
version: cec23d3e10b016363780d894a0eb732a12c06e02
- name: github.com/pingcap/check
version: ce8a2f822ab1e245a4eefcef2996531c79c943f1
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/siddontang/go
version: 354e14e6c093c661abb29fd28403b3c19cff5514
subpackages:
- hack
- ioutil2
- sync2
- name: golang.org/x/net
version: 6acef71eb69611914f7a30939ea9f6e194c78172
subpackages:
- context
testImports: []

View File

@ -1,26 +0,0 @@
package: github.com/siddontang/go-mysql
import:
- package: github.com/BurntSushi/toml
version: 056c9bc7be7190eaa7715723883caffa5f8fa3e4
- package: github.com/go-sql-driver/mysql
version: 3654d25ec346ee8ce71a68431025458d52a38ac0
- package: github.com/jmoiron/sqlx
version: 54aec3fd91a2b2129ffaca0d652b8a9223ee2d9e
subpackages:
- reflectx
- package: github.com/juju/errors
version: 6f54ff6318409d31ff16261533ce2c8381a4fd5d
- package: github.com/ngaut/log
version: cec23d3e10b016363780d894a0eb732a12c06e02
- package: github.com/pingcap/check
version: ce8a2f822ab1e245a4eefcef2996531c79c943f1
- package: github.com/satori/go.uuid
version: ^1.1.0
- package: github.com/siddontang/go
version: 354e14e6c093c661abb29fd28403b3c19cff5514
subpackages:
- hack
- ioutil2
- sync2
- package: golang.org/x/net
version: 6acef71eb69611914f7a30939ea9f6e194c78172

View File

@ -6,16 +6,22 @@ const (
TimeFormat string = "2006-01-02 15:04:05"
)
var (
// maybe you can change for your specified name
ServerVersion string = "go-mysql-0.1"
)
const (
OK_HEADER byte = 0x00
MORE_DATE_HEADER byte = 0x01
ERR_HEADER byte = 0xff
EOF_HEADER byte = 0xfe
LocalInFile_HEADER byte = 0xfb
CACHE_SHA2_FAST_AUTH byte = 0x03
CACHE_SHA2_FULL_AUTH byte = 0x04
)
const (
AUTH_MYSQL_OLD_PASSWORD = "mysql_old_password"
AUTH_NATIVE_PASSWORD = "mysql_native_password"
AUTH_CACHING_SHA2_PASSWORD = "caching_sha2_password"
AUTH_SHA256_PASSWORD = "sha256_password"
)
const (
@ -151,7 +157,6 @@ const (
)
const (
AUTH_NAME = "mysql_native_password"
DEFAULT_CHARSET = "utf8"
DEFAULT_COLLATION_ID uint8 = 33
DEFAULT_COLLATION_NAME string = "utf8_general_ci"

View File

@ -57,3 +57,10 @@ func NewError(errCode uint16, message string) *MyError {
return e
}
func ErrorCode(errMsg string) (code int) {
var tmpStr string
// golang scanf doesn't support %*,so I used a temporary variable
fmt.Sscanf(errMsg, "%s%d", &tmpStr, &code)
return
}

View File

@ -31,42 +31,42 @@ func (p FieldData) Parse() (f *Field, err error) {
var n int
pos := 0
//skip catelog, always def
n, err = SkipLengthEnodedString(p)
n, err = SkipLengthEncodedString(p)
if err != nil {
return
}
pos += n
//schema
f.Schema, _, n, err = LengthEnodedString(p[pos:])
f.Schema, _, n, err = LengthEncodedString(p[pos:])
if err != nil {
return
}
pos += n
//table
f.Table, _, n, err = LengthEnodedString(p[pos:])
f.Table, _, n, err = LengthEncodedString(p[pos:])
if err != nil {
return
}
pos += n
//org_table
f.OrgTable, _, n, err = LengthEnodedString(p[pos:])
f.OrgTable, _, n, err = LengthEncodedString(p[pos:])
if err != nil {
return
}
pos += n
//name
f.Name, _, n, err = LengthEnodedString(p[pos:])
f.Name, _, n, err = LengthEncodedString(p[pos:])
if err != nil {
return
}
pos += n
//org_name
f.OrgName, _, n, err = LengthEnodedString(p[pos:])
f.OrgName, _, n, err = LengthEncodedString(p[pos:])
if err != nil {
return
}

View File

@ -11,6 +11,10 @@ type GTIDSet interface {
Equal(o GTIDSet) bool
Contain(o GTIDSet) bool
Update(GTIDStr string) error
Clone() GTIDSet
}
func ParseGTIDSet(flavor string, s string) (GTIDSet, error) {

View File

@ -1,28 +1,32 @@
package mysql
import (
"bytes"
"fmt"
"strconv"
"strings"
"github.com/juju/errors"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go/hack"
)
// MariadbGTID represent mariadb gtid, [domain ID]-[server-id]-[sequence]
type MariadbGTID struct {
DomainID uint32
ServerID uint32
SequenceNumber uint64
}
// We don't support multi source replication, so the mariadb gtid set may have only domain-server-sequence
func ParseMariadbGTIDSet(str string) (GTIDSet, error) {
// ParseMariadbGTID parses mariadb gtid, [domain ID]-[server-id]-[sequence]
func ParseMariadbGTID(str string) (*MariadbGTID, error) {
if len(str) == 0 {
return MariadbGTID{0, 0, 0}, nil
return &MariadbGTID{0, 0, 0}, nil
}
seps := strings.Split(str, "-")
var gtid MariadbGTID
gtid := new(MariadbGTID)
if len(seps) != 3 {
return gtid, errors.Errorf("invalid Mariadb GTID %v, must domain-server-sequence", str)
@ -43,13 +47,13 @@ func ParseMariadbGTIDSet(str string) (GTIDSet, error) {
return gtid, errors.Errorf("invalid MariaDB GTID Sequence number (%v): %v", seps[2], err)
}
return MariadbGTID{
return &MariadbGTID{
DomainID: uint32(domainID),
ServerID: uint32(serverID),
SequenceNumber: sequenceID}, nil
}
func (gtid MariadbGTID) String() string {
func (gtid *MariadbGTID) String() string {
if gtid.DomainID == 0 && gtid.ServerID == 0 && gtid.SequenceNumber == 0 {
return ""
}
@ -57,24 +61,172 @@ func (gtid MariadbGTID) String() string {
return fmt.Sprintf("%d-%d-%d", gtid.DomainID, gtid.ServerID, gtid.SequenceNumber)
}
func (gtid MariadbGTID) Encode() []byte {
return []byte(gtid.String())
}
func (gtid MariadbGTID) Equal(o GTIDSet) bool {
other, ok := o.(MariadbGTID)
if !ok {
return false
}
return gtid == other
}
func (gtid MariadbGTID) Contain(o GTIDSet) bool {
other, ok := o.(MariadbGTID)
if !ok {
return false
}
// Contain return whether one mariadb gtid covers another mariadb gtid
func (gtid *MariadbGTID) Contain(other *MariadbGTID) bool {
return gtid.DomainID == other.DomainID && gtid.SequenceNumber >= other.SequenceNumber
}
// Clone clones a mariadb gtid
func (gtid *MariadbGTID) Clone() *MariadbGTID {
o := new(MariadbGTID)
*o = *gtid
return o
}
func (gtid *MariadbGTID) forward(newer *MariadbGTID) error {
if newer.DomainID != gtid.DomainID {
return errors.Errorf("%s is not same with doamin of %s", newer, gtid)
}
/*
Here's a simplified example of binlog events.
Although I think one domain should have only one update at same time, we can't limit the user's usage.
we just output a warn log and let it go on
| mysqld-bin.000001 | 1453 | Gtid | 112 | 1495 | BEGIN GTID 0-112-6 |
| mysqld-bin.000001 | 1624 | Xid | 112 | 1655 | COMMIT xid=74 |
| mysqld-bin.000001 | 1655 | Gtid | 112 | 1697 | BEGIN GTID 0-112-7 |
| mysqld-bin.000001 | 1826 | Xid | 112 | 1857 | COMMIT xid=75 |
| mysqld-bin.000001 | 1857 | Gtid | 111 | 1899 | BEGIN GTID 0-111-5 |
| mysqld-bin.000001 | 1981 | Xid | 111 | 2012 | COMMIT xid=77 |
| mysqld-bin.000001 | 2012 | Gtid | 112 | 2054 | BEGIN GTID 0-112-8 |
| mysqld-bin.000001 | 2184 | Xid | 112 | 2215 | COMMIT xid=116 |
| mysqld-bin.000001 | 2215 | Gtid | 111 | 2257 | BEGIN GTID 0-111-6 |
*/
if newer.SequenceNumber <= gtid.SequenceNumber {
log.Warnf("out of order binlog appears with gtid %s vs current position gtid %s", newer, gtid)
}
gtid.ServerID = newer.ServerID
gtid.SequenceNumber = newer.SequenceNumber
return nil
}
// MariadbGTIDSet is a set of mariadb gtid
type MariadbGTIDSet struct {
Sets map[uint32]*MariadbGTID
}
// ParseMariadbGTIDSet parses str into mariadb gtid sets
func ParseMariadbGTIDSet(str string) (GTIDSet, error) {
s := new(MariadbGTIDSet)
s.Sets = make(map[uint32]*MariadbGTID)
if str == "" {
return s, nil
}
sp := strings.Split(str, ",")
//todo, handle redundant same uuid
for i := 0; i < len(sp); i++ {
err := s.Update(sp[i])
if err != nil {
return nil, errors.Trace(err)
}
}
return s, nil
}
// AddSet adds mariadb gtid into mariadb gtid set
func (s *MariadbGTIDSet) AddSet(gtid *MariadbGTID) error {
if gtid == nil {
return nil
}
o, ok := s.Sets[gtid.DomainID]
if ok {
err := o.forward(gtid)
if err != nil {
return errors.Trace(err)
}
} else {
s.Sets[gtid.DomainID] = gtid
}
return nil
}
// Update updates mariadb gtid set
func (s *MariadbGTIDSet) Update(GTIDStr string) error {
gtid, err := ParseMariadbGTID(GTIDStr)
if err != nil {
return err
}
err = s.AddSet(gtid)
return errors.Trace(err)
}
func (s *MariadbGTIDSet) String() string {
return hack.String(s.Encode())
}
// Encode encodes mariadb gtid set
func (s *MariadbGTIDSet) Encode() []byte {
var buf bytes.Buffer
sep := ""
for _, gtid := range s.Sets {
buf.WriteString(sep)
buf.WriteString(gtid.String())
sep = ","
}
return buf.Bytes()
}
// Clone clones a mariadb gtid set
func (s *MariadbGTIDSet) Clone() GTIDSet {
clone := &MariadbGTIDSet{
Sets: make(map[uint32]*MariadbGTID),
}
for domainID, gtid := range s.Sets {
clone.Sets[domainID] = gtid.Clone()
}
return clone
}
// Equal returns true if two mariadb gtid set is same, otherwise return false
func (s *MariadbGTIDSet) Equal(o GTIDSet) bool {
other, ok := o.(*MariadbGTIDSet)
if !ok {
return false
}
if len(other.Sets) != len(s.Sets) {
return false
}
for domainID, gtid := range other.Sets {
o, ok := s.Sets[domainID]
if !ok {
return false
}
if *gtid != *o {
return false
}
}
return true
}
// Contain return whether one mariadb gtid set covers another mariadb gtid set
func (s *MariadbGTIDSet) Contain(o GTIDSet) bool {
other, ok := o.(*MariadbGTIDSet)
if !ok {
return false
}
for doaminID, gtid := range other.Sets {
o, ok := s.Sets[doaminID]
if !ok {
return false
}
if !o.Contain(gtid) {
return false
}
}
return true
}

View File

@ -97,7 +97,11 @@ func (s IntervalSlice) Normalize() IntervalSlice {
n = append(n, s[i])
continue
} else {
n[len(n)-1] = Interval{last.Start, s[i].Stop}
stop := s[i].Stop
if last.Stop > stop {
stop = last.Stop
}
n[len(n)-1] = Interval{last.Start, stop}
}
}
@ -285,17 +289,28 @@ func (s *UUIDSet) Decode(data []byte) error {
return err
}
func (s *UUIDSet) Clone() *UUIDSet {
clone := new(UUIDSet)
clone.SID, _ = uuid.FromString(s.SID.String())
clone.Intervals = s.Intervals.Normalize()
return clone
}
type MysqlGTIDSet struct {
Sets map[string]*UUIDSet
}
func ParseMysqlGTIDSet(str string) (GTIDSet, error) {
s := new(MysqlGTIDSet)
s.Sets = make(map[string]*UUIDSet)
if str == "" {
return s, nil
}
sp := strings.Split(str, ",")
s.Sets = make(map[string]*UUIDSet, len(sp))
//todo, handle redundant same uuid
for i := 0; i < len(sp); i++ {
if set, err := ParseUUIDSet(sp[i]); err != nil {
@ -334,6 +349,9 @@ func DecodeMysqlGTIDSet(data []byte) (*MysqlGTIDSet, error) {
}
func (s *MysqlGTIDSet) AddSet(set *UUIDSet) {
if set == nil {
return
}
sid := set.SID.String()
o, ok := s.Sets[sid]
if ok {
@ -343,6 +361,17 @@ func (s *MysqlGTIDSet) AddSet(set *UUIDSet) {
}
}
func (s *MysqlGTIDSet) Update(GTIDStr string) error {
uuidSet, err := ParseUUIDSet(GTIDStr)
if err != nil {
return err
}
s.AddSet(uuidSet)
return nil
}
func (s *MysqlGTIDSet) Contain(o GTIDSet) bool {
sub, ok := o.(*MysqlGTIDSet)
if !ok {
@ -407,3 +436,14 @@ func (s *MysqlGTIDSet) Encode() []byte {
return buf.Bytes()
}
func (gtid *MysqlGTIDSet) Clone() GTIDSet {
clone := &MysqlGTIDSet{
Sets: make(map[string]*UUIDSet),
}
for sid, uuidSet := range gtid.Sets {
clone.Sets[sid] = uuidSet.Clone()
}
return clone
}

View File

@ -1,6 +1,7 @@
package mysql
import (
"strings"
"testing"
"github.com/pingcap/check"
@ -15,11 +16,11 @@ type mysqlTestSuite struct {
var _ = check.Suite(&mysqlTestSuite{})
func (s *mysqlTestSuite) SetUpSuite(c *check.C) {
func (t *mysqlTestSuite) SetUpSuite(c *check.C) {
}
func (s *mysqlTestSuite) TearDownSuite(c *check.C) {
func (t *mysqlTestSuite) TearDownSuite(c *check.C) {
}
@ -59,6 +60,12 @@ func (t *mysqlTestSuite) TestMysqlGTIDIntervalSlice(c *check.C) {
n = i.Normalize()
c.Assert(n, check.DeepEquals, IntervalSlice{Interval{1, 3}, Interval{4, 5}})
i = IntervalSlice{Interval{1, 4}, Interval{2, 3}}
i.Sort()
c.Assert(i, check.DeepEquals, IntervalSlice{Interval{1, 4}, Interval{2, 3}})
n = i.Normalize()
c.Assert(n, check.DeepEquals, IntervalSlice{Interval{1, 4}})
n1 := IntervalSlice{Interval{1, 3}, Interval{4, 5}}
n2 := IntervalSlice{Interval{1, 2}}
@ -91,6 +98,15 @@ func (t *mysqlTestSuite) TestMysqlGTIDCodec(c *check.C) {
c.Assert(gs, check.DeepEquals, o)
}
func (t *mysqlTestSuite) TestMysqlUpdate(c *check.C) {
g1, err := ParseMysqlGTIDSet("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57")
c.Assert(err, check.IsNil)
g1.Update("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-58")
c.Assert(strings.ToUpper(g1.String()), check.Equals, "3E11FA47-71CA-11E1-9E33-C80AA9429562:21-58")
}
func (t *mysqlTestSuite) TestMysqlGTIDContain(c *check.C) {
g1, err := ParseMysqlGTIDSet("3E11FA47-71CA-11E1-9E33-C80AA9429562:23")
c.Assert(err, check.IsNil)
@ -151,3 +167,26 @@ func (t *mysqlTestSuite) TestMysqlParseBinaryUint64(c *check.C) {
u64 := ParseBinaryUint64([]byte{1, 2, 3, 4, 5, 6, 7, 128})
c.Assert(u64, check.Equals, 128*uint64(72057594037927936)+7*uint64(281474976710656)+6*uint64(1099511627776)+5*uint64(4294967296)+4*16777216+3*65536+2*256+1)
}
func (t *mysqlTestSuite) TestErrorCode(c *check.C) {
tbls := []struct {
msg string
code int
}{
{"ERROR 1094 (HY000): Unknown thread id: 1094", 1094},
{"error string", 0},
{"abcdefg", 0},
{"123455 ks094", 0},
{"ERROR 1046 (3D000): Unknown error 1046", 1046},
}
for _, v := range tbls {
c.Assert(ErrorCode(v.msg), check.Equals, v.code)
}
}
func (t *mysqlTestSuite) TestMysqlNullDecode(c *check.C) {
_, isNull, n := LengthEncodedInt([]byte{0xfb})
c.Assert(isNull, check.IsTrue)
c.Assert(n, check.Equals, 1)
}

View File

@ -28,7 +28,7 @@ func (p RowData) ParseText(f []*Field) ([]interface{}, error) {
var n int = 0
for i := range f {
v, isNull, n, err = LengthEnodedString(p[pos:])
v, isNull, n, err = LengthEncodedString(p[pos:])
if err != nil {
return nil, errors.Trace(err)
}
@ -115,7 +115,8 @@ func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) {
} else {
data[i] = ParseBinaryInt24(p[pos : pos+3])
}
pos += 4
//3 byte
pos += 3
continue
case MYSQL_TYPE_LONG:
@ -150,7 +151,7 @@ func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) {
MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB,
MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY:
v, isNull, n, err = LengthEnodedString(p[pos:])
v, isNull, n, err = LengthEncodedString(p[pos:])
pos += n
if err != nil {
return nil, errors.Trace(err)

View File

@ -38,6 +38,8 @@ func formatTextValue(value interface{}) ([]byte, error) {
return v, nil
case string:
return hack.Slice(v), nil
case nil:
return nil, nil
default:
return nil, errors.Errorf("invalid type %T", value)
}
@ -77,23 +79,40 @@ func formatBinaryValue(value interface{}) ([]byte, error) {
return nil, errors.Errorf("invalid type %T", value)
}
}
func fieldType(value interface{}) (typ uint8, err error) {
switch value.(type) {
case int8, int16, int32, int64, int:
typ = MYSQL_TYPE_LONGLONG
case uint8, uint16, uint32, uint64, uint:
typ = MYSQL_TYPE_LONGLONG
case float32, float64:
typ = MYSQL_TYPE_DOUBLE
case string, []byte:
typ = MYSQL_TYPE_VAR_STRING
case nil:
typ = MYSQL_TYPE_NULL
default:
err = errors.Errorf("unsupport type %T for resultset", value)
}
return
}
func formatField(field *Field, value interface{}) error {
switch value.(type) {
case int8, int16, int32, int64, int:
field.Charset = 63
field.Type = MYSQL_TYPE_LONGLONG
field.Flag = BINARY_FLAG | NOT_NULL_FLAG
case uint8, uint16, uint32, uint64, uint:
field.Charset = 63
field.Type = MYSQL_TYPE_LONGLONG
field.Flag = BINARY_FLAG | NOT_NULL_FLAG | UNSIGNED_FLAG
case float32, float64:
field.Charset = 63
field.Type = MYSQL_TYPE_DOUBLE
field.Flag = BINARY_FLAG | NOT_NULL_FLAG
case string, []byte:
field.Charset = 33
field.Type = MYSQL_TYPE_VAR_STRING
case nil:
field.Charset = 33
default:
return errors.Errorf("unsupport type %T for resultset", value)
}
@ -106,7 +125,13 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse
r.Fields = make([]*Field, len(names))
var b []byte
var err error
if len(values) == 0 {
for i, name := range names {
r.Fields[i] = &Field{Name: hack.Slice(name), Charset: 33, Type: MYSQL_TYPE_NULL}
}
return r, nil
}
for i, vs := range values {
if len(vs) != len(r.Fields) {
@ -115,13 +140,23 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse
var row []byte
for j, value := range vs {
if i == 0 {
field := &Field{}
r.Fields[j] = field
field.Name = hack.Slice(names[j])
if err = formatField(field, value); err != nil {
return nil, errors.Trace(err)
typ, err := fieldType(value)
if err != nil {
return nil, errors.Trace(err)
}
if r.Fields[j] == nil {
r.Fields[j] = &Field{Name: hack.Slice(names[j]), Type: typ}
formatField(r.Fields[j], value)
} else if typ != r.Fields[j].Type {
// we got another type in the same column. in general, we treat it as an error, except
// the case, when old value was null, and the new one isn't null, so we can update
// type info for fields.
oldIsNull, newIsNull := r.Fields[j].Type == MYSQL_TYPE_NULL, typ == MYSQL_TYPE_NULL
if oldIsNull && !newIsNull { // old is null, new isn't, update type info.
r.Fields[j].Type = typ
formatField(r.Fields[j], value)
} else if !oldIsNull && !newIsNull { // different non-null types, that's an error.
return nil, errors.Errorf("row types aren't consistent")
}
}
b, err = formatTextValue(value)
@ -130,7 +165,12 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse
return nil, errors.Trace(err)
}
row = append(row, PutLengthEncodedString(b)...)
if b == nil {
// NULL value is encoded as 0xfb here (without additional info about length)
row = append(row, 0xfb)
} else {
row = append(row, PutLengthEncodedString(b)...)
}
}
r.RowDatas = append(r.RowDatas, row)
@ -145,7 +185,6 @@ func BuildSimpleBinaryResultset(names []string, values [][]interface{}) (*Result
r.Fields = make([]*Field, len(names))
var b []byte
var err error
bitmapLen := ((len(names) + 7 + 2) >> 3)
@ -161,8 +200,12 @@ func BuildSimpleBinaryResultset(names []string, values [][]interface{}) (*Result
row = append(row, nullBitmap...)
for j, value := range vs {
typ, err := fieldType(value)
if err != nil {
return nil, errors.Trace(err)
}
if i == 0 {
field := &Field{}
field := &Field{Type: typ}
r.Fields[j] = field
field.Name = hack.Slice(names[j])

View File

@ -11,6 +11,8 @@ import (
"github.com/juju/errors"
"github.com/siddontang/go/hack"
"crypto/sha256"
"crypto/rsa"
)
func Pstack() string {
@ -48,6 +50,62 @@ func CalcPassword(scramble, password []byte) []byte {
return scramble
}
// Hash password using MySQL 8+ method (SHA256)
func CalcCachingSha2Password(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
crypt := sha256.New()
crypt.Write([]byte(password))
message1 := crypt.Sum(nil)
crypt.Reset()
crypt.Write(message1)
message1Hash := crypt.Sum(nil)
crypt.Reset()
crypt.Write(message1Hash)
crypt.Write(scramble)
message2 := crypt.Sum(nil)
for i := range message1 {
message1[i] ^= message2[i]
}
return message1
}
func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
plain := make([]byte, len(password)+1)
copy(plain, password)
for i := range plain {
j := i % len(seed)
plain[i] ^= seed[j]
}
sha1v := sha1.New()
return rsa.EncryptOAEP(sha1v, rand.Reader, pub, plain, nil)
}
// encodes a uint64 value and appends it to the given bytes slice
func AppendLengthEncodedInteger(b []byte, n uint64) []byte {
switch {
case n <= 250:
return append(b, byte(n))
case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8))
case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
}
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}
func RandomBuf(size int) ([]byte, error) {
buf := make([]byte, size)
@ -84,39 +142,33 @@ func BFixedLengthInt(buf []byte) uint64 {
}
func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
switch b[0] {
if len(b) == 0 {
return 0, true, 1
}
switch b[0] {
// 251: NULL
case 0xfb:
n = 1
isNull = true
return
return 0, true, 1
// 252: value of following 2
// 252: value of following 2
case 0xfc:
num = uint64(b[1]) | uint64(b[2])<<8
n = 3
return
return uint64(b[1]) | uint64(b[2])<<8, false, 3
// 253: value of following 3
// 253: value of following 3
case 0xfd:
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
n = 4
return
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
// 254: value of following 8
// 254: value of following 8
case 0xfe:
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56
n = 9
return
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
}
// 0-250: value of first byte
num = uint64(b[0])
n = 1
return
return uint64(b[0]), false, 1
}
func PutLengthEncodedInt(n uint64) []byte {
@ -137,23 +189,26 @@ func PutLengthEncodedInt(n uint64) []byte {
return nil
}
func LengthEnodedString(b []byte) ([]byte, bool, int, error) {
// returns the string read as a bytes slice, whether the value is NULL,
// the number of bytes read and an error, in case the string is longer than
// the input slice
func LengthEncodedString(b []byte) ([]byte, bool, int, error) {
// Get length
num, isNull, n := LengthEncodedInt(b)
if num < 1 {
return nil, isNull, n, nil
return b[n:n], isNull, n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return b[n-int(num) : n], false, n, nil
return b[n-int(num) : n : n], false, n, nil
}
return nil, false, n, io.EOF
}
func SkipLengthEnodedString(b []byte) (int, error) {
func SkipLengthEncodedString(b []byte) (int, error) {
// Get length
num, _, n := LengthEncodedInt(b)
if num < 1 {

Some files were not shown because too many files have changed in this diff Show More