updated go-mysql

Shlomi Noach 2019-01-01 10:58:12 +02:00
parent 255314927d
commit 14e396ed73
107 changed files with 21297 additions and 0 deletions

vendor/github.com/siddontang/go-mysql/Gopkg.lock generated vendored Normal file

@ -0,0 +1,78 @@
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]]
name = "github.com/BurntSushi/toml"
packages = ["."]
revision = "b26d9c308763d68093482582cea63d69be07a0f0"
version = "v0.3.0"
[[projects]]
branch = "master"
name = "github.com/go-sql-driver/mysql"
packages = ["."]
revision = "99ff426eb706cffe92ff3d058e168b278cabf7c7"
[[projects]]
branch = "master"
name = "github.com/jmoiron/sqlx"
packages = [
".",
"reflectx"
]
revision = "2aeb6a910c2b94f2d5eb53d9895d80e27264ec41"
[[projects]]
branch = "master"
name = "github.com/juju/errors"
packages = ["."]
revision = "c7d06af17c68cd34c835053720b21f6549d9b0ee"
[[projects]]
branch = "master"
name = "github.com/pingcap/check"
packages = ["."]
revision = "1c287c953996ab3a0bf535dba9d53d809d3dc0b6"
[[projects]]
name = "github.com/satori/go.uuid"
packages = ["."]
revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3"
version = "v1.2.0"
[[projects]]
name = "github.com/shopspring/decimal"
packages = ["."]
revision = "cd690d0c9e2447b1ef2a129a6b7b49077da89b8e"
version = "1.1.0"
[[projects]]
branch = "master"
name = "github.com/siddontang/go"
packages = [
"hack",
"sync2"
]
revision = "2b7082d296ba89ae7ead0f977816bddefb65df9d"
[[projects]]
branch = "master"
name = "github.com/siddontang/go-log"
packages = [
"log",
"loggers"
]
revision = "a4d157e46fa3e08b7e7ff329af341fa3ff86c02c"
[[projects]]
name = "google.golang.org/appengine"
packages = ["cloudsql"]
revision = "b1f26356af11148e710935ed1ac8a7f5702c7612"
version = "v1.1.0"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "a1f9939938a58551bbb3f19411c9d1386995d36296de6f6fb5d858f5923db85e"
solver-name = "gps-cdcl"
solver-version = 1

vendor/github.com/siddontang/go-mysql/Gopkg.toml generated vendored Normal file

@ -0,0 +1,56 @@
# Gopkg.toml example
#
# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html
# for detailed Gopkg.toml documentation.
#
# required = ["github.com/user/thing/cmd/thing"]
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
#
# [[constraint]]
# name = "github.com/user/project"
# version = "1.0.0"
#
# [[constraint]]
# name = "github.com/user/project2"
# branch = "dev"
# source = "github.com/myfork/project2"
#
# [[override]]
# name = "github.com/x/y"
# version = "2.4.0"
#
# [prune]
# non-go = false
# go-tests = true
# unused-packages = true
[[constraint]]
name = "github.com/BurntSushi/toml"
version = "v0.3.0"
[[constraint]]
name = "github.com/go-sql-driver/mysql"
branch = "master"
[[constraint]]
branch = "master"
name = "github.com/juju/errors"
[[constraint]]
name = "github.com/satori/go.uuid"
version = "v1.2.0"
[[constraint]]
name = "github.com/shopspring/decimal"
version = "v1.1.0"
[[constraint]]
branch = "master"
name = "github.com/siddontang/go"
[prune]
go-tests = true
unused-packages = true
non-go = true

vendor/github.com/siddontang/go-mysql/clear_vendor.sh generated vendored Executable file

@ -0,0 +1,6 @@
find vendor \( -type f -or -type l \) -not -name "*.go" -not -name "LICENSE" -not -name "*.s" -not -name "PATENTS" -not -name "*.h" -not -name "*.c" | xargs -I {} rm {}
# delete all generated and test files
find vendor -type f -name "*_generated.go" | xargs -I {} rm {}
find vendor -type f -name "*_test.go" | xargs -I {} rm {}
find vendor -type d -name "_vendor" | xargs -I {} rm -rf {}
find vendor -type d -empty | xargs -I {} rm -rf {}

vendor/github.com/siddontang/go-mysql/client/tls.go generated vendored Normal file

@ -0,0 +1,28 @@
package client
import (
"crypto/tls"
"crypto/x509"
)
// generate TLS config for client side
// if insecureSkipVerify is set to true, serverName will not be validated
func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caPem) {
panic("failed to add ca PEM")
}
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
config := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: pool,
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
}
return config
}
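For reference, a minimal client-side sketch of how this helper might be used; the PEM file names and server name are hypothetical placeholders and error handling is elided:
package main
import (
"io/ioutil"
"github.com/siddontang/go-mysql/client"
)
func main() {
// Load the CA, client certificate and client key from disk (placeholder paths).
caPem, _ := ioutil.ReadFile("ca.pem")
certPem, _ := ioutil.ReadFile("client-cert.pem")
keyPem, _ := ioutil.ReadFile("client-key.pem")
// Build a verifying TLS config; "db.example.com" must match the server certificate.
tlsConf := client.NewClientTLSConfig(caPem, certPem, keyPem, false, "db.example.com")
_ = tlsConf // pass this config to the client connection as appropriate
}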


@ -0,0 +1,80 @@
version: '3'
services:
mysql-5.5.61:
image: "mysql:5.5.61"
container_name: "mysql-server-5.5.61"
ports:
- "5561:3306"
command: --ssl=TRUE --ssl-ca=/usr/local/mysql/ca.pem --ssl-cert=/usr/local/mysql/server-cert.pem --ssl-key=/usr/local/mysql/server-key.pem
volumes:
- ./resources/ca.pem:/usr/local/mysql/ca.pem
- ./resources/server-cert.pem:/usr/local/mysql/server-cert.pem
- ./resources/server-key.pem:/usr/local/mysql/server-key.pem
environment:
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- bind-address=0.0.0.0
mysql-5.6.41:
image: "mysql:5.6.41"
container_name: "mysql-server-5.6.41"
ports:
- "5641:3306"
command: --ssl=TRUE --ssl-ca=/usr/local/mysql/ca.pem --ssl-cert=/usr/local/mysql/server-cert.pem --ssl-key=/usr/local/mysql/server-key.pem
volumes:
- ./resources/ca.pem:/usr/local/mysql/ca.pem
- ./resources/server-cert.pem:/usr/local/mysql/server-cert.pem
- ./resources/server-key.pem:/usr/local/mysql/server-key.pem
environment:
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- bind-address=0.0.0.0
mysql-default:
image: "mysql:5.7.22"
container_name: "mysql-server-default"
ports:
- "3306:3306"
command: ["mysqld", "--log-bin=mysql-bin", "--server-id=1"]
environment:
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- bind-address=0.0.0.0
mysql-5.7.22:
image: "mysql:5.7.22"
container_name: "mysql-server-5.7.22"
ports:
- "5722:3306"
environment:
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- bind-address=0.0.0.0
mysql-8.0.3:
image: "mysql:8.0.3"
container_name: "mysql-server-8.0.3"
ports:
- "8003:3306"
environment:
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- bind-address=0.0.0.0
mysql-8.0.12:
image: "mysql:8.0.12"
container_name: "mysql-server-8.0.12"
ports:
- "8012:3306"
environment:
#- MYSQL_ROOT_PASSWORD=abc123
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- bind-address=0.0.0.0
mysql-8.0.12-sha256:
image: "mysql:8.0.12"
container_name: "mysql-server-8.0.12-sha256"
ports:
- "8013:3306"
entrypoint: ['/entrypoint.sh', '--default-authentication-plugin=sha256_password']
environment:
#- MYSQL_ROOT_PASSWORD=abc123
- MYSQL_ALLOW_EMPTY_PASSWORD=true
- bind-address=0.0.0.0
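As a rough sketch (assuming the containers above are running with the declared port mappings), any of these servers can be reached with go-sql-driver/mysql, which is pinned in the Gopkg.lock above; root has an empty password because of MYSQL_ALLOW_EMPTY_PASSWORD:
package main
import (
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
)
func main() {
// 127.0.0.1:3306 maps to the "mysql-default" container (mysql:5.7.22).
db, err := sql.Open("mysql", "root:@tcp(127.0.0.1:3306)/mysql")
if err != nil {
panic(err)
}
defer db.Close()
var version string
if err := db.QueryRow("SELECT VERSION()").Scan(&version); err != nil {
panic(err)
}
fmt.Println("connected to", version)
}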


@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QL
suT1fGf5zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eo
xrS64GSbPVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKv
lpEAtRAe4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7
IeXUc456yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioO
NaePQKFR/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABAoIBAC2U0jponRiGmgIl
gohw6+D+6pNeaKAAUkwYbKXJZ3noWLFr4T3GDTg9WDqvcvJg+rT9NvZxdCW3tDc5
CVBcwO1g9PVcUEaRqcme3EhrxKdQQ76QmjUGeQf1ktd+YnmiZ1kOnGLtZ9/gsYpQ
06iGSIOX3+xA4BQOhEAPCOShMjYv+pWvWGhZCSmeulKulNVPBbG2H1I9EoT5Wd+Q
8LUfgZOuUXrtcsuvEf2XeacCo0pUbjx8ErhDHP6aPasFAXq15Bm8DnsUOrrsjcLy
sPy/mHwpd6kTw+O3EzjTdaYSFRoDSpfpIS5Bk+yicdxOmTwp1pzDu6HyYnuOnc9Q
JQ8HvlECgYEA2z+1HKVz5k7NYyRihW4l30vAcAGcgG1RObB6DmLbGu4MPvMymgLO
1QhYjlCcKfRHhVS2864op3Oba2fIgCc2am0DIQQ6kZ23ick78aj9G2ZXYpdpIPLu
Kl1AZHj6XDrOPVqidwcE6iYHLLWp9x4Atgw5d44XmhQ0kwrqAfccOX8CgYEAzxnl
7Uu+v5WI3hBVJpxS6eoS1TdztVIJaumyE43pBoHEuJrp4MRf0Lu2DiDpH8R3/RoE
o+ykn6xzphYwUopYaCWzYTKoXvxCvmLkDjHcpdzLtwWbKG+MJih2nTADEDI7sK4e
a3IU8miK6FeqkQHfs/5dlQa8q31yxiukw0qQEP0CgYAtLg6jTZD5l6mJUZkfx9f0
EMciDaLzcBN54Nz2E/b0sLNDUZhO1l9K1QJyqTfVCWqnlhJxWqU0BIW1d1iA2BPF
kJtBdX6gPTDyKs64eMtXlxpQzcSzLnxXrIm1apyk3tVbHU83WfHwUk/OLc1NiBg7
a394HIbOkHVZC7m3F/Xv/wKBgQDHrM2du8D+kJs0l4SxxFjAxPlBb8R01tLTrNwP
tGwu5OEZp+rE1jEXXFRMTPjXsyKI+hPtRJT4ilm6kXwnqNFSIL9RgHkLk6Z6T3hY
I0T8+ePD43jURLBYffzW0tqxO+2HDGmx6H0/twHuv89pHehkb2Qk8ijoIvyNCrlB
vVsntQKBgCK04nbb+G45D6TKCcZ6XKT/+qneJQE5cfvHl5EqrfjSmlnEUpJjJfyc
6Q1PtXtWOtOScU93f1JKL7+JBbWDn9uBlboM8BSkAVVd/2vyg88RuEtIru1syxcW
d1rMxqaMRJuhuqaS33CoPUpn30b4zVrPhQJ2+TwDAol4qIGHaie8
-----END RSA PRIVATE KEY-----


@ -0,0 +1,22 @@
-----BEGIN CERTIFICATE-----
MIIDtTCCAp2gAwIBAgIJANeS1FOzWXlZMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTgwODE2MTUxNDE5WhcNMjEwNjA1MTUxNDE5WjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
CgKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QLsuT1fGf5
zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eoxrS64GSb
PVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKvlpEAtRAe
4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7IeXUc456
yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioONaePQKFR
/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABo4GnMIGkMB0GA1UdDgQWBBQgHiwD
00upIbCOunlK4HRw89DhjjB1BgNVHSMEbjBsgBQgHiwD00upIbCOunlK4HRw89Dh
jqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV
BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJANeS1FOzWXlZMAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAFMZFQTFKU5tWIpWh8BbVZeVZcng0Kiq
qwbhVwaTkqtfmbqw8/w+faOWylmLncQEMmgvnUltGMQlQKBwQM2byzPkz9phal3g
uI0JWJYqtcMyIQUB9QbbhrDNC9kdt/ji/x6rrIqzaMRuiBXqH5LQ9h856yXzArqd
cAQGzzYpbUCIv7ciSB93cKkU73fQLZVy5ZBy1+oAa1V9U4cb4G/20/PDmT+G3Gxz
pEjeDKtz8XINoWgA2cSdfAhNZt5vqJaCIZ8qN0z6C7SUKwUBderERUMLUXdhUldC
KTVHyEPvd0aULd5S5vEpKCnHcQmFcLdoN8t9k9pR9ZgwqXbyJHlxWFo=
-----END CERTIFICATE-----


@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDBjCCAe4CCQDg06wCf7hcuTANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB
VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0
cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDY0N1oXDTI4MDgxNjA4NDY0N1owRTELMAkG
A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0
IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AMmivNyk3Rc1ZvLPhb3WPNkf9f2G4g9nMc0+eMrR1IKJ1U1A98ojeIBT+pfk1bSq
Ol0UDm66Vd3YQ+4HpyYHaYV6mwoTEulL9Quk8RLa7TRwQu3PLi3o567RhVIrx8Z3
umuWb9UUzJfSFH04Uy9+By4CJCqIQXU4BocLIKHhIkNjmAQ9fWO1hZ8zmPHSEfvu
Wqa/DYKGvF0MJr4Lnkm/sKUd+O94p9suvwM6OGIDibACiKRF2H+JbgQLbA58zkLv
DHtXOqsCL7HxiONX8VDrQjN/66Nh9omk/Bx2Ec8IqappHvWf768HSH79x/znaial
VEV+6K0gP+voJHfnA10laWMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAPD+Fn1qj
HN62GD3eIgx6wJxYuemhdbgmEwrZZf4V70lS6e9Iloif0nBiISDxJUpXVWNRCN3Z
3QVC++F7deDmWL/3dSpXRvWsapzbCUhVQ2iBcnZ7QCOdvAqYR1ecZx70zvXCwBcd
6XKmRtdeNV6B211KRFmTYtVyPq4rcWrkTPGwPBncJI1eQQmyFv2T9SwVVp96Nbrq
sf7zrJGmuVCdXGPRi/ALVHtJCz6oPoft3I707eMe+ijnFqwGbmMD4fMD6Ync/hEz
PyR5FMZkXSXHS0gkA5pfwW7wJ2WSWDhI6JMS1gbatY7QzgHbKoQpxBPUXlnzzj2h
7O9cgFTh/XOZXQ==
-----END CERTIFICATE-----


@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAyaK83KTdFzVm8s+FvdY82R/1/YbiD2cxzT54ytHUgonVTUD3
yiN4gFP6l+TVtKo6XRQObrpV3dhD7genJgdphXqbChMS6Uv1C6TxEtrtNHBC7c8u
LejnrtGFUivHxne6a5Zv1RTMl9IUfThTL34HLgIkKohBdTgGhwsgoeEiQ2OYBD19
Y7WFnzOY8dIR++5apr8Ngoa8XQwmvgueSb+wpR3473in2y6/Azo4YgOJsAKIpEXY
f4luBAtsDnzOQu8Me1c6qwIvsfGI41fxUOtCM3/ro2H2iaT8HHYRzwipqmke9Z/v
rwdIfv3H/OdqJqVURX7orSA/6+gkd+cDXSVpYwIDAQABAoIBAAGLY5L1GFRzLkSx
3j5kA7dODV5RyC2CBtmhnt8+2DffwmiDFOLRfrzM5+B9+j0WCLhpzOqANuQqIesS
1+7so5xIIiPjnYN393qNWuNgFe0O5xRXP+1OGWg3ZqQIfdFBXYYxcs3ZCPAoxctn
wQteFcP+dDR3MrkpIrOqHCfhR5foieOMP+9k5kCjk+aZqhEmFyko+X+xVO/32xs+
+3qXhUrHt3Op5on30QMOFguniQlYwLJkd9qVjGuGMIrVPxoUz0rya4SKrGKgkAr8
mvQe2+sZo7cc6zC2ceaGMJU7z1RalTrCObbg5mynlu+Vf0E/YiES0abkQhSbcSB9
mAkJC7ECgYEA/H1NDEiO164yYK9ji4HM/8CmHegWS4qsgrzAs8lU0yAcgdg9e19A
mNi8yssfIBCw62RRE4UGWS5F82myhmvq/mXbf8eCJ2CMgdCHQh1rT7WFD/Uc5Pe/
8Lv2jNMQ61POguPyq6D0qcf8iigKIMHa1MIgAOmrgWrxulfbSUhm370CgYEAzHBu
J9p4dAqW32+Hrtv2XE0KUjH72TXr13WErosgeGTfsIW2exXByvLasxOJSY4Wb8oS
OLZ7bgp/EBchAc7my+nF8n5uOJxipWQUB5BoeB9aUJZ9AnWF4RDl94Jlm5PYBG/J
lRXrMtSTTIgmSw3Ft2A1vRMOQaHX89lNwOZL758CgYAXOT84/gOFexRPKFKzpkDA
1WtyHMLQN/UeIVZoMwCGWtHEb6tYCa7bYDQdQwmd3Wsoe5WpgfbPhR4SAYrWKl72
/09tNWCXVp4V4qRORH52Wm/ew+Dgfpk8/0zyLwfDXXYFPAo6Fxfp9ecYng4wbSQ/
pYtkChooUTniteoJl4s+0QKBgHbFEpoAqF3yEPi52L/TdmrlLwvVkhT86IkB8xVc
Kn8HS5VH+V3EpBN9x2SmAupCq/JCGRftnAOwAWWdqkVcqGTq6V8Z6HrnD8A6RhCm
6qpuvI94/iNBl4fLw25pyRH7cFITh68fTsb3DKQ3rNeJpsYEFPRFb9Ddb5JxOmTI
5nDNAoGBAM+SyOhUGU+0Uw2WJaGWzmEutjeMRr5Z+cZ8keC/ZJNdji/faaQoeOQR
OXI8O6RBTBwVNQMyDyttT8J8BkISwfAhSdPkjgPw9GZ1pGREl53uCFDIlX2nvtQM
ioNzG5WHB7Gd7eUUTA91kRF9MZJTHPqNiNGR0Udj/trGyGqJebni
-----END RSA PRIVATE KEY-----


@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDBjCCAe4CCQDg06wCf7hcuDANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB
VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0
cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDUyNVoXDTI4MDgxNjA4NDUyNVowRTELMAkG
A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0
IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
ALK2gqK4uvTlxJANO2JKdibvmh899z6oCo9Km0mz5unj4dpnq9hljsQuKtcHUcM4
HXcE06knaJ4TOF7lcsjaqoDO7r/SaFgjjXCqNvHD0Su4B+7qe52BZZTRV1AANP10
PvebarXSEzaZUCyHHhSF8+Qb4vX04XKX/TOqinTVGtlnduKzP5+qsaFBtpLAw1V0
At9EQB5BgnTYtdIsmvD4/2WhBvOjVxab75yx0R4oof4F3u528tbEegcWhBtmy2Xd
HI3S+TLljj3kOOdB+pgrVUl+KaDavWK3T+F1vTNDe56HEVNKeWlLy1scul61E0j9
IkZAu6aRDxtKdl7bKu0BkzMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAma3yFqR7
xkeaZBg4/1I3jSlaNe5+2JB4iybAkMOu77fG5zytLomTbzdhewsuBwpTVMJdga8T
IdPeIFCin1U+5SkbjSMlpKf+krE+5CyrNJ5jAgO9ATIqx66oCTYXfGlNapGRLfSE
sa0iMqCe/dr4GPU+flW2DZFWiyJVDSF1JjReQnfrWY+SD2SpP/lmlgltnY8MJngd
xBLG5nsZCpUXGB713Q8ZyIm2ThVAMiskcxBleIZDDghLuhGvY/9eFJhZpvOkjWa6
XGEi4E1G/SA+zVKFl41nHKCdqXdmIOnpcLlFBUVloQok5a95Kqc1TYw3f+WbdFff
99dAgk3gWwWZQA==
-----END CERTIFICATE-----


@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAsraCori69OXEkA07Ykp2Ju+aHz33PqgKj0qbSbPm6ePh2mer
2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqqgM7uv9JoWCONcKo28cPRK7gH7up7
nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz5Bvi9fThcpf9M6qKdNUa2Wd24rM/
n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ZaEG86NXFpvvnLHRHiih/gXe7nby
1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtVSX4poNq9YrdP4XW9M0N7nocRU0p5
aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGTMwIDAQABAoIBAGh1m8hHWCg7gXh9
838RbRx3IswuKS27hWiaQEiFWmzOIb7KqDy1qAxtu+ayRY1paHegH6QY/+Kd824s
ibpzbgQacJ04/HrAVTVMmQ8Z2VLHoAN7lcPL1bd14aZGaLLZVtDeTDJ413grhxxv
4ho27gcgcbo4Z+rWgk7H2WRPCAGYqWYAycm3yF5vy9QaO6edU+T588YsEQOos5iy
5pVFSGDGZkcUp1ukL3BJYR+jvygn6WPCobQ/LScUdi+ucitaI9i+UdlLokZARVRG
M/msqcTM73thR8yVRcexU6NUDxRBfZ/f7moSAEbBmGDXuxDcIyH9KGMQ2rMtN1X3
lK8UNwkCgYEA2STJq/IUQHjdqd3Dqh/Q7Zm8/pMWFqLJSkqpnFtFsXPyUOx9zDOy
KqkIfGeyKwvsj9X9BcZ0FUKj9zoct1/WpPY+h7i7+z0MIujBh4AMjAcDrt4o76yK
UHuVmG2xKTdJoAbqOdToQeX6E82Ioal5pbB2W7AbCQScNBPZ52jxgtcCgYEA0rE7
2dFiRm0YmuszFYxft2+GP6NgP3R2TQNEooi1uCXG2xgwObie1YCHzpZ5CfSqJIxP
XB7DXpIWi7PxJoeai2F83LnmdFz6F1BPRobwDoSFNdaSKLg4Yf856zpgYNKhL1fE
OoOXj4VBWBZh1XDfZV44fgwlMIf7edOF1XOagwUCgYAw953O+7FbdKYwF0V3iOM5
oZDAK/UwN5eC/GFRVDfcM5RycVJRCVtlSWcTfuLr2C2Jpiz/72fgH34QU3eEVsV1
v94MBznFB1hESw7ReqvZq/9FoO3EVrl+OtBaZmosLD6bKtQJJJ0Xtz/01UW5hxla
pveZ55XBK9v51nwuNjk4UwKBgHD8fJUllSchUCWb5cwzeAz98Kdl7LJ6uQo5q2/i
EllLYOWThiEeIYdrIuklholRPIDXAaPsF2c6vn5yo+q+o6EFSZlw0+YpCjDAb5Lp
wAh5BprFk6HkkM/0t9Guf4rMyYWC8odSlE9x7YXYkuSMYDCTI4Zs6vCoq7I8PbQn
B4AlAoGAZ6Ee5m/ph5UVp/3+cR6jCY7aHBUU/M3pbJSkVjBW+ymEBVJ6sUdz8k3P
x8BiPEQggNN7faWBqRWP7KXPnDYHh6shYUgPJwI5HX6NE/ZDnnXjeysHRyf0oCo5
S6tHXwHNKB5HS1c/KDyyNGjP2oi/MF4o/MGWNWEcK6TJA3RGOYM=
-----END RSA PRIVATE KEY-----


@ -0,0 +1,234 @@
package mysql
import (
"github.com/pingcap/check"
)
type mariaDBTestSuite struct {
}
var _ = check.Suite(&mariaDBTestSuite{})
func (t *mariaDBTestSuite) SetUpSuite(c *check.C) {
}
func (t *mariaDBTestSuite) TearDownSuite(c *check.C) {
}
func (t *mariaDBTestSuite) TestParseMariaDBGTID(c *check.C) {
cases := []struct {
gtidStr string
hasError bool
}{
{"0-1-1", false},
{"", false},
{"0-1-1-1", true},
{"1", true},
{"0-1-seq", true},
}
for _, cs := range cases {
gtid, err := ParseMariadbGTID(cs.gtidStr)
if cs.hasError {
c.Assert(err, check.NotNil)
} else {
c.Assert(err, check.IsNil)
c.Assert(gtid.String(), check.Equals, cs.gtidStr)
}
}
}
func (t *mariaDBTestSuite) TestMariaDBGTIDContain(c *check.C) {
cases := []struct {
originGTIDStr, otherGTIDStr string
contain bool
}{
{"0-1-1", "0-1-2", false},
{"0-1-1", "", true},
{"2-1-1", "1-1-1", false},
{"1-2-1", "1-1-1", true},
{"1-2-2", "1-1-1", true},
}
for _, cs := range cases {
originGTID, err := ParseMariadbGTID(cs.originGTIDStr)
c.Assert(err, check.IsNil)
otherGTID, err := ParseMariadbGTID(cs.otherGTIDStr)
c.Assert(err, check.IsNil)
c.Assert(originGTID.Contain(otherGTID), check.Equals, cs.contain)
}
}
func (t *mariaDBTestSuite) TestMariaDBGTIDClone(c *check.C) {
gtid, err := ParseMariadbGTID("1-1-1")
c.Assert(err, check.IsNil)
clone := gtid.Clone()
c.Assert(gtid, check.DeepEquals, clone)
}
func (t *mariaDBTestSuite) TestMariaDBForward(c *check.C) {
cases := []struct {
currentGTIDStr, newerGTIDStr string
hasError bool
}{
{"0-1-1", "0-1-2", false},
{"0-1-1", "", false},
{"2-1-1", "1-1-1", true},
{"1-2-1", "1-1-1", false},
{"1-2-2", "1-1-1", false},
}
for _, cs := range cases {
currentGTID, err := ParseMariadbGTID(cs.currentGTIDStr)
c.Assert(err, check.IsNil)
newerGTID, err := ParseMariadbGTID(cs.newerGTIDStr)
c.Assert(err, check.IsNil)
err = currentGTID.forward(newerGTID)
if cs.hasError {
c.Assert(err, check.NotNil)
c.Assert(currentGTID.String(), check.Equals, cs.currentGTIDStr)
} else {
c.Assert(err, check.IsNil)
c.Assert(currentGTID.String(), check.Equals, cs.newerGTIDStr)
}
}
}
func (t *mariaDBTestSuite) TestParseMariaDBGTIDSet(c *check.C) {
cases := []struct {
gtidStr string
subGTIDs map[uint32]string //domain ID => gtid string
expectedStr []string // test String()
hasError bool
}{
{"0-1-1", map[uint32]string{0: "0-1-1"}, []string{"0-1-1"}, false},
{"", nil, []string{""}, false},
{"0-1-1,1-2-3", map[uint32]string{0: "0-1-1", 1: "1-2-3"}, []string{"0-1-1,1-2-3", "1-2-3,0-1-1"}, false},
{"0-1--1", nil, nil, true},
}
for _, cs := range cases {
gtidSet, err := ParseMariadbGTIDSet(cs.gtidStr)
if cs.hasError {
c.Assert(err, check.NotNil)
} else {
c.Assert(err, check.IsNil)
mariadbGTIDSet, ok := gtidSet.(*MariadbGTIDSet)
c.Assert(ok, check.IsTrue)
// check sub gtid
c.Assert(mariadbGTIDSet.Sets, check.HasLen, len(cs.subGTIDs))
for domainID, gtid := range mariadbGTIDSet.Sets {
c.Assert(mariadbGTIDSet.Sets, check.HasKey, domainID)
c.Assert(gtid.String(), check.Equals, cs.subGTIDs[domainID])
}
// check String() function
inExpectedResult := false
actualStr := mariadbGTIDSet.String()
for _, str := range cs.expectedStr {
if str == actualStr {
inExpectedResult = true
break
}
}
c.Assert(inExpectedResult, check.IsTrue)
}
}
}
func (t *mariaDBTestSuite) TestMariaDBGTIDSetUpdate(c *check.C) {
cases := []struct {
isNilGTID bool
gtidStr string
subGTIDs map[uint32]string
}{
{true, "", map[uint32]string{1: "1-1-1", 2: "2-2-2"}},
{false, "1-2-2", map[uint32]string{1: "1-2-2", 2: "2-2-2"}},
{false, "1-2-1", map[uint32]string{1: "1-2-1", 2: "2-2-2"}},
{false, "3-2-1", map[uint32]string{1: "1-1-1", 2: "2-2-2", 3: "3-2-1"}},
}
for _, cs := range cases {
gtidSet, err := ParseMariadbGTIDSet("1-1-1,2-2-2")
c.Assert(err, check.IsNil)
mariadbGTIDSet, ok := gtidSet.(*MariadbGTIDSet)
c.Assert(ok, check.IsTrue)
if cs.isNilGTID {
c.Assert(mariadbGTIDSet.AddSet(nil), check.IsNil)
} else {
err := gtidSet.Update(cs.gtidStr)
c.Assert(err, check.IsNil)
}
// check sub gtid
c.Assert(mariadbGTIDSet.Sets, check.HasLen, len(cs.subGTIDs))
for domainID, gtid := range mariadbGTIDSet.Sets {
c.Assert(mariadbGTIDSet.Sets, check.HasKey, domainID)
c.Assert(gtid.String(), check.Equals, cs.subGTIDs[domainID])
}
}
}
func (t *mariaDBTestSuite) TestMariaDBGTIDSetEqual(c *check.C) {
cases := []struct {
originGTIDStr, otherGTIDStr string
equals bool
}{
{"", "", true},
{"1-1-1", "1-1-1,2-2-2", false},
{"1-1-1,2-2-2", "1-1-1", false},
{"1-1-1,2-2-2", "1-1-1,2-2-2", true},
{"1-1-1,2-2-2", "1-1-1,2-2-3", false},
}
for _, cs := range cases {
originGTID, err := ParseMariadbGTIDSet(cs.originGTIDStr)
c.Assert(err, check.IsNil)
otherGTID, err := ParseMariadbGTIDSet(cs.otherGTIDStr)
c.Assert(err, check.IsNil)
c.Assert(originGTID.Equal(otherGTID), check.Equals, cs.equals)
}
}
func (t *mariaDBTestSuite) TestMariaDBGTIDSetContain(c *check.C) {
cases := []struct {
originGTIDStr, otherGTIDStr string
contain bool
}{
{"", "", true},
{"1-1-1", "1-1-1,2-2-2", false},
{"1-1-1,2-2-2", "1-1-1", true},
{"1-1-1,2-2-2", "1-1-1,2-2-2", true},
{"1-1-1,2-2-2", "1-1-1,2-2-1", true},
{"1-1-1,2-2-2", "1-1-1,2-2-3", false},
}
for _, cs := range cases {
originGTIDSet, err := ParseMariadbGTIDSet(cs.originGTIDStr)
c.Assert(err, check.IsNil)
otherGTIDSet, err := ParseMariadbGTIDSet(cs.otherGTIDStr)
c.Assert(err, check.IsNil)
c.Assert(originGTIDSet.Contain(otherGTIDSet), check.Equals, cs.contain)
}
}
func (t *mariaDBTestSuite) TestMariaDBGTIDSetClone(c *check.C) {
cases := []string{"", "1-1-1", "1-1-1,2-2-2"}
for _, str := range cases {
gtidSet, err := ParseMariadbGTIDSet(str)
c.Assert(err, check.IsNil)
c.Assert(gtidSet.Clone(), check.DeepEquals, gtidSet)
}
}
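A short sketch of the API exercised by the tests above (error handling elided); the expected results are taken directly from the test cases:
package main
import (
"fmt"
"github.com/siddontang/go-mysql/mysql"
)
func main() {
// A MariaDB GTID has the form domain-server-sequence.
a, _ := mysql.ParseMariadbGTID("1-2-2")
b, _ := mysql.ParseMariadbGTID("1-1-1")
fmt.Println(a.Contain(b)) // true: same domain, higher sequence (see TestMariaDBGTIDContain)
// A GTID set keeps one GTID per domain, comma separated.
set, _ := mysql.ParseMariadbGTIDSet("0-1-1,1-2-3")
if mset, ok := set.(*mysql.MariadbGTIDSet); ok {
fmt.Println(mset.String()) // "0-1-1,1-2-3" (domain order may vary)
}
}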


@ -0,0 +1,50 @@
package replication
import (
"context"
"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/siddontang/go-mysql/mysql"
"os"
"sync"
"time"
)
func (t *testSyncerSuite) TestStartBackupEndInGivenTime(c *C) {
t.setupTest(c, mysql.MySQLFlavor)
t.testExecute(c, "RESET MASTER")
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
t.testSync(c, nil)
t.testExecute(c, "FLUSH LOGS")
t.testSync(c, nil)
}()
os.RemoveAll("./var")
timeout := 2 * time.Second
done := make(chan bool)
go func() {
err := t.b.StartBackup("./var", mysql.Position{Name: "", Pos: uint32(0)}, timeout)
c.Assert(err, IsNil)
done <- true
}()
failTimeout := 5 * timeout
ctx, cancel := context.WithTimeout(context.Background(), failTimeout)
defer cancel()
select {
case <-done:
return
case <-ctx.Done():
c.Assert(errors.New("time out error"), IsNil)
}
}


@ -0,0 +1,49 @@
package replication
import (
"fmt"
"strings"
"time"
)
var (
fracTimeFormat []string
)
// fracTime is a helper structure wrapping Go's time.Time.
type fracTime struct {
time.Time
// Dec must be in [0, 6]
Dec int
timestampStringLocation *time.Location
}
func (t fracTime) String() string {
tt := t.Time
if t.timestampStringLocation != nil {
tt = tt.In(t.timestampStringLocation)
}
return tt.Format(fracTimeFormat[t.Dec])
}
func formatZeroTime(frac int, dec int) string {
if dec == 0 {
return "0000-00-00 00:00:00"
}
s := fmt.Sprintf("0000-00-00 00:00:00.%06d", frac)
// dec must be <= 6; if frac is 924000 but dec is 3, we must output 924 here.
return s[0 : len(s)-(6-dec)]
}
func init() {
fracTimeFormat = make([]string, 7)
fracTimeFormat[0] = "2006-01-02 15:04:05"
for i := 1; i <= 6; i++ {
fracTimeFormat[i] = fmt.Sprintf("2006-01-02 15:04:05.%s", strings.Repeat("0", i))
}
}


@ -0,0 +1,70 @@
package replication
import (
"time"
. "github.com/pingcap/check"
)
type testTimeSuite struct{}
var _ = Suite(&testTimeSuite{})
func (s *testTimeSuite) TestTime(c *C) {
tbls := []struct {
year int
month int
day int
hour int
min int
sec int
microSec int
frac int
expected string
}{
{2000, 1, 1, 1, 1, 1, 1, 0, "2000-01-01 01:01:01"},
{2000, 1, 1, 1, 1, 1, 1, 1, "2000-01-01 01:01:01.0"},
{2000, 1, 1, 1, 1, 1, 1, 6, "2000-01-01 01:01:01.000001"},
}
for _, t := range tbls {
t1 := fracTime{time.Date(t.year, time.Month(t.month), t.day, t.hour, t.min, t.sec, t.microSec*1000, time.UTC), t.frac, nil}
c.Assert(t1.String(), Equals, t.expected)
}
zeroTbls := []struct {
frac int
dec int
expected string
}{
{0, 1, "0000-00-00 00:00:00.0"},
{1, 1, "0000-00-00 00:00:00.0"},
{123, 3, "0000-00-00 00:00:00.000"},
{123000, 3, "0000-00-00 00:00:00.123"},
{123, 6, "0000-00-00 00:00:00.000123"},
{123000, 6, "0000-00-00 00:00:00.123000"},
}
for _, t := range zeroTbls {
c.Assert(formatZeroTime(t.frac, t.dec), Equals, t.expected)
}
}
func (s *testTimeSuite) TestTimeStringLocation(c *C) {
t := fracTime{
time.Date(2018, time.Month(7), 30, 10, 0, 0, 0, time.FixedZone("EST", -5*3600)),
0,
nil,
}
c.Assert(t.String(), Equals, "2018-07-30 10:00:00")
t = fracTime{
time.Date(2018, time.Month(7), 30, 10, 0, 0, 0, time.FixedZone("EST", -5*3600)),
0,
time.UTC,
}
c.Assert(t.String(), Equals, "2018-07-30 15:00:00")
}


@ -0,0 +1,133 @@
package server
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/tls"
"fmt"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
)
func (c *Conn) handleAuthSwitchResponse() error {
authData, err := c.readAuthSwitchRequestResponse()
if err != nil {
return err
}
switch c.authPluginName {
case AUTH_NATIVE_PASSWORD:
if err := c.acquirePassword(); err != nil {
return err
}
if !bytes.Equal(CalcPassword(c.salt, []byte(c.password)), authData) {
return ErrAccessDenied
}
return nil
case AUTH_CACHING_SHA2_PASSWORD:
if !c.cachingSha2FullAuth {
// Switched auth method but no MoreData packet sent yet
if err := c.compareCacheSha2PasswordAuthData(authData); err != nil {
return err
} else {
if c.cachingSha2FullAuth {
return c.handleAuthSwitchResponse()
}
return nil
}
}
// AuthMoreData packet already sent, do full auth
if err := c.handleCachingSha2PasswordFullAuth(authData); err != nil {
return err
}
c.writeCachingSha2Cache()
return nil
case AUTH_SHA256_PASSWORD:
cont, err := c.handlePublicKeyRetrieval(authData)
if err != nil {
return err
}
if !cont {
return nil
}
if err := c.acquirePassword(); err != nil {
return err
}
return c.compareSha256PasswordAuthData(authData, c.password)
default:
return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName)
}
}
func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
if err := c.acquirePassword(); err != nil {
return err
}
if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok {
if !tlsConn.ConnectionState().HandshakeComplete {
return errors.New("incomplete TSL handshake")
}
// the connection is SSL/TLS, so the client sends the password in plain text
// strip the trailing \NUL appended to the plain-text password received
if l := len(authData); l != 0 && authData[l-1] == 0x00 {
authData = authData[:l-1]
}
if bytes.Equal(authData, []byte(c.password)) {
return nil
}
return ErrAccessDenied
} else {
// the client either requests the public key or sends the encrypted password
if len(authData) == 1 && authData[0] == 0x02 {
// send the public key
if err := c.writeAuthMoreDataPubkey(); err != nil {
return err
}
// read the encrypted password
var err error
if authData, err = c.readAuthSwitchRequestResponse(); err != nil {
return err
}
}
// the encrypted password
// decrypt
dbytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, (c.serverConf.tlsConfig.Certificates[0].PrivateKey).(*rsa.PrivateKey), authData, nil)
if err != nil {
return err
}
plain := make([]byte, len(c.password)+1)
copy(plain, c.password)
for i := range plain {
j := i % len(c.salt)
plain[i] ^= c.salt[j]
}
if bytes.Equal(plain, dbytes) {
return nil
}
return ErrAccessDenied
}
}
func (c *Conn) writeCachingSha2Cache() {
// write cache
if c.password == "" {
return
}
// SHA256(PASSWORD)
crypt := sha256.New()
crypt.Write([]byte(c.password))
m1 := crypt.Sum(nil)
// SHA256(SHA256(PASSWORD))
crypt.Reset()
crypt.Write(m1)
m2 := crypt.Sum(nil)
// caching_sha2_password will maintain an in-memory hash of `user`@`host` => SHA256(SHA256(PASSWORD))
c.serverConf.cacheShaPassword.Store(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr()), m2)
}


@ -0,0 +1,233 @@
package server
import (
"database/sql"
"fmt"
"net"
"strings"
"sync"
"testing"
"time"
_ "github.com/go-sql-driver/mysql"
"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/test_util/test_keys"
)
var delay = 50
// test caching for 'caching_sha2_password'
// NOTE the idea here is to plug in a throttled credential provider so that the first connection (cache miss) takes longer
// than the second connection (cache hit). Remember to set a password for the MySQL user, otherwise an empty password will not be cached.
func TestCachingSha2Cache(t *testing.T) {
log.SetLevel(log.LevelDebug)
remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50}
remoteProvider.AddUser(*testUser, *testPassword)
cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf)
// no TLS
Suite(&cacheTestSuite{
server: cacheServer,
credProvider: remoteProvider,
tlsPara: "false",
})
TestingT(t)
}
func TestCachingSha2CacheTLS(t *testing.T) {
log.SetLevel(log.LevelDebug)
remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50}
remoteProvider.AddUser(*testUser, *testPassword)
cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf)
// TLS
Suite(&cacheTestSuite{
server: cacheServer,
credProvider: remoteProvider,
tlsPara: "skip-verify",
})
TestingT(t)
}
type RemoteThrottleProvider struct {
*InMemoryProvider
delay int // in milliseconds
}
func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) {
time.Sleep(time.Millisecond * time.Duration(m.delay))
return m.InMemoryProvider.GetCredential(username)
}
type cacheTestSuite struct {
server *Server
credProvider CredentialProvider
tlsPara string
db *sql.DB
l net.Listener
}
func (s *cacheTestSuite) SetUpSuite(c *C) {
var err error
s.l, err = net.Listen("tcp", *testAddr)
c.Assert(err, IsNil)
go s.onAccept(c)
time.Sleep(30 * time.Millisecond)
}
func (s *cacheTestSuite) TearDownSuite(c *C) {
if s.l != nil {
s.l.Close()
}
}
func (s *cacheTestSuite) onAccept(c *C) {
for {
conn, err := s.l.Accept()
if err != nil {
return
}
go s.onConn(conn, c)
}
}
func (s *cacheTestSuite) onConn(conn net.Conn, c *C) {
//co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s})
co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testCacheHandler{s})
c.Assert(err, IsNil)
for {
err = co.HandleCommand()
if err != nil {
return
}
}
}
func (s *cacheTestSuite) runSelect(c *C) {
var a int64
var b string
err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=1").Scan(&a, &b)
c.Assert(err, IsNil)
c.Assert(a, Equals, int64(1))
c.Assert(b, Equals, "hello world")
}
func (s *cacheTestSuite) TestCache(c *C) {
// first connection
t1 := time.Now()
var err error
s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara))
c.Assert(err, IsNil)
s.db.SetMaxIdleConns(4)
s.runSelect(c)
t2 := time.Now()
d1 := int(t2.Sub(t1).Nanoseconds() / 1e6)
//log.Debugf("first connection took %d milliseconds", d1)
c.Assert(d1, GreaterEqual, delay)
if s.db != nil {
s.db.Close()
}
// second connection
t3 := time.Now()
s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara))
c.Assert(err, IsNil)
s.db.SetMaxIdleConns(4)
s.runSelect(c)
t4 := time.Now()
d2 := int(t4.Sub(t3).Nanoseconds() / 1e6)
//log.Debugf("second connection took %d milliseconds", d2)
c.Assert(d2, Less, delay)
if s.db != nil {
s.db.Close()
}
s.server.cacheShaPassword = &sync.Map{}
}
type testCacheHandler struct {
s *cacheTestSuite
}
func (h *testCacheHandler) UseDB(dbName string) error {
return nil
}
func (h *testCacheHandler) handleQuery(query string, binary bool) (*mysql.Result, error) {
ss := strings.Split(query, " ")
switch strings.ToLower(ss[0]) {
case "select":
var r *mysql.Resultset
var err error
//handle the go-sql-driver/mysql "select @@max_allowed_packet" query
if strings.Contains(strings.ToLower(query), "max_allowed_packet") {
r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{
{mysql.MaxPayloadLen},
}, binary)
} else {
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
{1, "hello world"},
}, binary)
}
if err != nil {
return nil, errors.Trace(err)
} else {
return &mysql.Result{0, 0, 0, r}, nil
}
case "insert":
return &mysql.Result{0, 1, 0, nil}, nil
case "delete":
return &mysql.Result{0, 0, 1, nil}, nil
case "update":
return &mysql.Result{0, 0, 1, nil}, nil
case "replace":
return &mysql.Result{0, 0, 1, nil}, nil
default:
return nil, fmt.Errorf("invalid query %s", query)
}
return nil, nil
}
func (h *testCacheHandler) HandleQuery(query string) (*mysql.Result, error) {
return h.handleQuery(query, false)
}
func (h *testCacheHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {
return nil, nil
}
func (h *testCacheHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) {
return 0, 0, nil, nil
}
func (h *testCacheHandler) HandleStmtClose(context interface{}) error {
return nil
}
func (h *testCacheHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) {
return h.handleQuery(query, true)
}
func (h *testCacheHandler) HandleOtherCommand(cmd byte, data []byte) error {
return mysql.NewError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("command %d is not supported now", cmd))
}


@ -0,0 +1,4 @@
package server
// Ensure EmptyHandler implements the Handler interface, or fail at compile time
var _ Handler = EmptyHandler{}


@ -0,0 +1,45 @@
package server
import "sync"
// interface for user credential provider
// hint: can be extended for more functionality
// =================================IMPORTANT NOTE===============================
// if the password in a third-party credential provider can be updated at runtime, the cache for
// 'caching_sha2_password' must be invalidated by calling 'func (s *Server) InvalidateCache(string, string)'.
type CredentialProvider interface {
// check if the user exists
CheckUsername(username string) (bool, error)
// get user credential
GetCredential(username string) (password string, found bool, err error)
}
func NewInMemoryProvider() *InMemoryProvider {
return &InMemoryProvider{
userPool: sync.Map{},
}
}
// InMemoryProvider implements an in-memory credential provider
type InMemoryProvider struct {
userPool sync.Map // username -> password
}
func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error) {
_, ok := m.userPool.Load(username)
return ok, nil
}
func (m *InMemoryProvider) GetCredential(username string) (password string, found bool, err error) {
v, ok := m.userPool.Load(username)
if !ok {
return "", false, nil
}
return v.(string), true, nil
}
func (m *InMemoryProvider) AddUser(username, password string) {
m.userPool.Store(username, password)
}
type Provider InMemoryProvider
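A minimal sketch of using the in-memory provider defined above (the server example later in this diff plugs a provider like this into NewCustomizedConn):
package main
import (
"fmt"
"github.com/siddontang/go-mysql/server"
)
func main() {
p := server.NewInMemoryProvider()
p.AddUser("root", "123")
found, _ := p.CheckUsername("root")    // true
pass, ok, _ := p.GetCredential("root") // "123", true
fmt.Println(found, ok, pass)
}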


@ -0,0 +1,51 @@
package main
import (
"net"
"github.com/siddontang/go-log/log"
"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/server"
"github.com/siddontang/go-mysql/test_util/test_keys"
"crypto/tls"
"time"
)
type RemoteThrottleProvider struct {
*server.InMemoryProvider
delay int // in milliseconds
}
func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) {
time.Sleep(time.Millisecond * time.Duration(m.delay))
return m.InMemoryProvider.GetCredential(username)
}
func main() {
l, _ := net.Listen("tcp", "127.0.0.1:3306")
// use either the in-memory credential provider or the remote credential provider (you can implement your own)
//inMemProvider := server.NewInMemoryProvider()
//inMemProvider.AddUser("root", "123")
remoteProvider := &RemoteThrottleProvider{server.NewInMemoryProvider(), 10 + 50}
remoteProvider.AddUser("root", "123")
var tlsConf = server.NewServerTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, tls.VerifyClientCertIfGiven)
for {
c, _ := l.Accept()
go func() {
// Create a connection for user root with password 123, as registered with the provider above.
// You can use your own handler to handle commands here.
svr := server.NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf)
conn, err := server.NewCustomizedConn(c, svr, remoteProvider, server.EmptyHandler{})
if err != nil {
log.Errorf("Connection error: %v", err)
return
}
for {
conn.HandleCommand()
}
}()
}
}


@ -0,0 +1,190 @@
package server
import (
"bytes"
"crypto/tls"
"encoding/binary"
"github.com/juju/errors"
. "github.com/siddontang/go-mysql/mysql"
)
func (c *Conn) readHandshakeResponse() error {
data, pos, err := c.readFirstPart()
if err != nil {
return err
}
if pos, err = c.readUserName(data, pos); err != nil {
return err
}
authData, authLen, pos, err := c.readAuthData(data, pos)
if err != nil {
return err
}
pos += authLen
if pos, err = c.readDb(data, pos); err != nil {
return err
}
pos = c.readPluginName(data, pos)
cont, err := c.handleAuthMatch(authData, pos)
if err != nil {
return err
}
if !cont {
return nil
}
// ignore connect attrs for now, the proxy does not support passing attrs to the actual MySQL server
// try to authenticate the client
return c.compareAuthData(c.authPluginName, authData)
}
func (c *Conn) readFirstPart() ([]byte, int, error) {
data, err := c.ReadPacket()
if err != nil {
return nil, 0, err
}
pos := 0
// check CLIENT_PROTOCOL_41
if uint32(binary.LittleEndian.Uint16(data[:2]))&CLIENT_PROTOCOL_41 == 0 {
return nil, 0, errors.New("CLIENT_PROTOCOL_41 compatible client is required")
}
//capability
c.capability = binary.LittleEndian.Uint32(data[:4])
if c.capability&CLIENT_SECURE_CONNECTION == 0 {
return nil, 0, errors.New("CLIENT_SECURE_CONNECTION compatible client is required")
}
pos += 4
//skip max packet size
pos += 4
//charset, skip, if you want to use another charset, use set names
//c.collation = CollationId(data[pos])
pos++
//skip reserved 23[00]
pos += 23
// is this a SSLRequest packet?
if len(data) == (4 + 4 + 1 + 23) {
if c.serverConf.capability&CLIENT_SSL == 0 {
return nil, 0, errors.Errorf("The host '%s' does not support SSL connections", c.RemoteAddr().String())
}
// switch to TLS
tlsConn := tls.Server(c.Conn.Conn, c.serverConf.tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return nil, 0, err
}
c.Conn.Conn = tlsConn
// mysql handshake again
return c.readFirstPart()
}
return data, pos, nil
}
func (c *Conn) readUserName(data []byte, pos int) (int, error) {
//user name
user := string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
pos += len(user) + 1
c.user = user
return pos, nil
}
func (c *Conn) readDb(data []byte, pos int) (int, error) {
if c.capability&CLIENT_CONNECT_WITH_DB != 0 {
if len(data[pos:]) == 0 {
return pos, nil
}
db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
pos += len(db) + 1
if err := c.h.UseDB(db); err != nil {
return 0, err
}
}
return pos, nil
}
func (c *Conn) readPluginName(data []byte, pos int) int {
if c.capability&CLIENT_PLUGIN_AUTH != 0 {
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
pos += len(c.authPluginName)
} else {
// The method used is Native Authentication if both CLIENT_PROTOCOL_41 and CLIENT_SECURE_CONNECTION are set,
// but CLIENT_PLUGIN_AUTH is not set, so we fall back to 'mysql_native_password'
c.authPluginName = AUTH_NATIVE_PASSWORD
}
return pos
}
func (c *Conn) readAuthData(data []byte, pos int) ([]byte, int, int, error) {
// length encoded data
var auth []byte
var authLen int
if c.capability&CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
authData, isNULL, readBytes, err := LengthEncodedString(data[pos:])
if err != nil {
return nil, 0, 0, err
}
if isNULL {
// no auth length and no auth data, just a \NUL; treat it as invalid auth data and reject the connection, as MySQL does
return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.LocalAddr().String(), c.user, "Yes")
}
auth = authData
authLen = readBytes
} else {
//auth length and auth
authLen = int(data[pos])
pos++
auth = data[pos : pos+authLen]
if authLen == 0 {
// skip the next \NUL in case the password is empty
pos++
}
}
return auth, authLen, pos, nil
}
// Public Key Retrieval
// See: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html
func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) {
// if the client uses the 'sha256_password' auth method and requests the public key,
// we send back a keyfile with Protocol::AuthMoreData
if c.authPluginName == AUTH_SHA256_PASSWORD && len(authData) == 1 && authData[0] == 0x01 {
if c.serverConf.capability&CLIENT_SSL == 0 {
return false, errors.New("server does not support SSL: CLIENT_SSL not enabled")
}
if err := c.writeAuthMoreDataPubkey(); err != nil {
return false, err
}
return false, c.handleAuthSwitchResponse()
}
return true, nil
}
func (c *Conn) handleAuthMatch(authData []byte, pos int) (bool, error) {
// if the client responds to the handshake with a different auth method, the server sends the AuthSwitchRequest packet
// to the client to ask the client to switch.
if c.authPluginName != c.serverConf.defaultAuthMethod {
if err := c.writeAuthSwitchRequest(c.serverConf.defaultAuthMethod); err != nil {
return false, err
}
c.authPluginName = c.serverConf.defaultAuthMethod
// handle AuthSwitchResponse
return false, c.handleAuthSwitchResponse()
}
return true, nil
}


@ -0,0 +1,57 @@
package server
// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
func (c *Conn) writeInitialHandshake() error {
data := make([]byte, 4)
//min version 10
data = append(data, 10)
//server version[00]
data = append(data, c.serverConf.serverVersion...)
data = append(data, 0x00)
//connection id
data = append(data, byte(c.connectionID), byte(c.connectionID>>8), byte(c.connectionID>>16), byte(c.connectionID>>24))
//auth-plugin-data-part-1
data = append(data, c.salt[0:8]...)
//filler 0x00 byte, terminating the first part of the scramble
data = append(data, 0x00)
defaultFlag := c.serverConf.capability
//capability flag lower 2 bytes, using default capability here
data = append(data, byte(defaultFlag), byte(defaultFlag>>8))
//charset
data = append(data, c.serverConf.collationId)
//status
data = append(data, byte(c.status), byte(c.status>>8))
//capability flag upper 2 bytes, using default capability here
data = append(data, byte(defaultFlag>>16), byte(defaultFlag>>24))
// server supports CLIENT_PLUGIN_AUTH and CLIENT_SECURE_CONNECTION
data = append(data, byte(8+12+1))
//reserved 10 [00]
data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
//auth-plugin-data-part-2
data = append(data, c.salt[8:]...)
// second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
// add \NUL to terminate the string
data = append(data, 0x00)
// auth plugin name
data = append(data, c.serverConf.defaultAuthMethod...)
// EOF if MySQL version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
// \NUL otherwise, so we use \NUL
data = append(data, 0)
return c.WritePacket(data)
}


@ -0,0 +1,103 @@
package server
import (
"crypto/tls"
"fmt"
"sync"
. "github.com/siddontang/go-mysql/mysql"
)
var defaultServer = NewDefaultServer()
// Defines a basic MySQL server with configs.
//
// We do not aim at implementing the whole MySQL connection suite to have the best compatibilities for the clients.
// The MySQL server can be configured to switch auth methods covering 'mysql_old_password', 'mysql_native_password',
// 'mysql_clear_password', 'authentication_windows_client', 'sha256_password', 'caching_sha2_password', etc.
//
// However, some old auth methods are considered broken due to security issues. MySQL major versions like 5.7 and 8.0 default to
// 'mysql_native_password' or 'caching_sha2_password', and most MySQL clients should have already supported at least one of the three auth
// methods 'mysql_native_password', 'caching_sha2_password', and 'sha256_password'. Thus here we will only support these three
// auth methods, and use 'mysql_native_password' as default for maximum compatibility with the clients and leave the other two as
// config options.
//
// The MySQL docs state that 'mysql_old_password' will be used if the 'CLIENT_PROTOCOL_41' or 'CLIENT_SECURE_CONNECTION' flag is not set.
// We choose to drop support for the insecure 'mysql_old_password' auth method and require that the client capabilities 'CLIENT_PROTOCOL_41' and 'CLIENT_SECURE_CONNECTION'
// are set. Besides, if 'CLIENT_PLUGIN_AUTH' is not set, we fall back to the 'mysql_native_password' auth method.
type Server struct {
serverVersion string // e.g. "8.0.12"
protocolVersion int // minimal 10
capability uint32 // server capability flag
collationId uint8
defaultAuthMethod string // default authentication method, 'mysql_native_password'
pubKey []byte
tlsConfig *tls.Config
cacheShaPassword *sync.Map // 'user@host' -> SHA256(SHA256(PASSWORD))
}
// New mysql server with default settings.
//
// NOTES:
// TLS support is enabled by default with auto-generated CA and server certificates (however, you can still use a
// non-TLS connection). By default, the server verifies the client certificate if one is present. You can enable TLS support on
// the client side without providing a client-side certificate. Only when you need the server to verify the client
// identity for maximum security do you need to set a signed certificate for the client.
func NewDefaultServer() *Server {
caPem, caKey := generateCA()
certPem, keyPem := generateAndSignRSACerts(caPem, caKey)
tlsConf := NewServerTLSConfig(caPem, certPem, keyPem, tls.VerifyClientCertIfGiven)
return &Server{
serverVersion: "5.7.0",
protocolVersion: 10,
capability: CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 |
CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_SSL | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
collationId: DEFAULT_COLLATION_ID,
defaultAuthMethod: AUTH_NATIVE_PASSWORD,
pubKey: getPublicKeyFromCert(certPem),
tlsConfig: tlsConf,
cacheShaPassword: new(sync.Map),
}
}
// New mysql server with customized settings.
//
// NOTES:
// You can control the authentication methods and TLS settings here.
// For the auth method, you can specify one of the supported methods: 'mysql_native_password', 'caching_sha2_password', and 'sha256_password'.
// The specified auth method is enforced by the server in the connection phase; that means the client will be asked to switch auth methods
// if the supplied auth method differs from the server default.
// For TLS support, you can specify self-signed or CA-signed certificates and decide whether the client needs to provide
// a signed or unsigned certificate, giving different levels of security.
func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config) *Server {
if !isAuthMethodSupported(defaultAuthMethod) {
panic(fmt.Sprintf("server authentication method '%s' is not supported", defaultAuthMethod))
}
//if !isAuthMethodAllowedByServer(defaultAuthMethod, allowedAuthMethods) {
// panic(fmt.Sprintf("default auth method is not one of the allowed auth methods"))
//}
var capFlag = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 |
CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
if tlsConfig != nil {
capFlag |= CLIENT_SSL
}
return &Server{
serverVersion: serverVersion,
protocolVersion: 10,
capability: capFlag,
collationId: collationId,
defaultAuthMethod: defaultAuthMethod,
pubKey: pubKey,
tlsConfig: tlsConfig,
cacheShaPassword: new(sync.Map),
}
}
func isAuthMethodSupported(authMethod string) bool {
return authMethod == AUTH_NATIVE_PASSWORD || authMethod == AUTH_CACHING_SHA2_PASSWORD || authMethod == AUTH_SHA256_PASSWORD
}
func (s *Server) InvalidateCache(username string, host string) {
s.cacheShaPassword.Delete(fmt.Sprintf("%s@%s", username, host))
}

vendor/github.com/siddontang/go-mysql/server/ssl.go generated vendored Normal file

@ -0,0 +1,133 @@
package server
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"time"
)
// generate TLS config for server side
// controlling the security level by authType
func NewServerTLSConfig(caPem, certPem, keyPem []byte, authType tls.ClientAuthType) *tls.Config {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caPem) {
panic("failed to add ca PEM")
}
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
config := &tls.Config{
ClientAuth: authType,
Certificates: []tls.Certificate{cert},
ClientCAs: pool,
}
return config
}
// extract RSA public key from certificate
func getPublicKeyFromCert(certPem []byte) []byte {
block, _ := pem.Decode(certPem)
crt, err := x509.ParseCertificate(block.Bytes)
if err != nil {
panic(err)
}
pubKey, err := x509.MarshalPKIXPublicKey(crt.PublicKey.(*rsa.PublicKey))
if err != nil {
panic(err)
}
return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubKey})
}
// generate and sign RSA certificates with given CA
// see: https://fale.io/blog/2017/06/05/create-a-pki-in-golang/
func generateAndSignRSACerts(caPem, caKey []byte) ([]byte, []byte) {
// Load CA
catls, err := tls.X509KeyPair(caPem, caKey)
if err != nil {
panic(err)
}
ca, err := x509.ParseCertificate(catls.Certificate[0])
if err != nil {
panic(err)
}
// use the CA to sign certificates
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
panic(err)
}
cert := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"ORGANIZATION_NAME"},
Country: []string{"COUNTRY_CODE"},
Province: []string{"PROVINCE"},
Locality: []string{"CITY"},
StreetAddress: []string{"ADDRESS"},
PostalCode: []string{"POSTAL_CODE"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
// sign the certificate
cert_b, err := x509.CreateCertificate(rand.Reader, ca, cert, &priv.PublicKey, catls.PrivateKey)
if err != nil {
panic(err)
}
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert_b})
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
return certPem, keyPem
}
// generate CA in PEM
// see: https://github.com/golang/go/blob/master/src/crypto/tls/generate_cert.go
func generateCA() ([]byte, []byte) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
panic(err)
}
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"ORGANIZATION_NAME"},
Country: []string{"COUNTRY_CODE"},
Province: []string{"PROVINCE"},
Locality: []string{"CITY"},
StreetAddress: []string{"ADDRESS"},
PostalCode: []string{"POSTAL_CODE"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
}
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
panic(err)
}
caPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
caKey := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
return caPem, caKey
}
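For illustration, a sketch (inside package server, since generateCA and generateAndSignRSACerts are unexported) of how these helpers compose into a *tls.Config, mirroring what NewDefaultServer does; the function name is hypothetical:
package server
import "crypto/tls"
// exampleServerTLS generates a throw-away CA, signs a server certificate with it,
// and builds a TLS config that verifies a client certificate only if one is given.
func exampleServerTLS() *tls.Config {
caPem, caKey := generateCA()
certPem, keyPem := generateAndSignRSACerts(caPem, caKey)
return NewServerTLSConfig(caPem, certPem, keyPem, tls.VerifyClientCertIfGiven)
}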


@ -0,0 +1,85 @@
package test_keys
// this file holds the encryption keys used for testing
// NOTE THIS IS FOR TESTING ONLY, DO NOT USE THEM IN PRODUCTION!
var PubPem = []byte(`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsraCori69OXEkA07Ykp2
Ju+aHz33PqgKj0qbSbPm6ePh2mer2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqq
gM7uv9JoWCONcKo28cPRK7gH7up7nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz
5Bvi9fThcpf9M6qKdNUa2Wd24rM/n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/
ZaEG86NXFpvvnLHRHiih/gXe7nby1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtV
SX4poNq9YrdP4XW9M0N7nocRU0p5aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGT
MwIDAQAB
-----END PUBLIC KEY-----`)
var CertPem = []byte(`-----BEGIN CERTIFICATE-----
MIIDBjCCAe4CCQDg06wCf7hcuDANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB
VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0
cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDUyNVoXDTI4MDgxNjA4NDUyNVowRTELMAkG
A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0
IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
ALK2gqK4uvTlxJANO2JKdibvmh899z6oCo9Km0mz5unj4dpnq9hljsQuKtcHUcM4
HXcE06knaJ4TOF7lcsjaqoDO7r/SaFgjjXCqNvHD0Su4B+7qe52BZZTRV1AANP10
PvebarXSEzaZUCyHHhSF8+Qb4vX04XKX/TOqinTVGtlnduKzP5+qsaFBtpLAw1V0
At9EQB5BgnTYtdIsmvD4/2WhBvOjVxab75yx0R4oof4F3u528tbEegcWhBtmy2Xd
HI3S+TLljj3kOOdB+pgrVUl+KaDavWK3T+F1vTNDe56HEVNKeWlLy1scul61E0j9
IkZAu6aRDxtKdl7bKu0BkzMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAma3yFqR7
xkeaZBg4/1I3jSlaNe5+2JB4iybAkMOu77fG5zytLomTbzdhewsuBwpTVMJdga8T
IdPeIFCin1U+5SkbjSMlpKf+krE+5CyrNJ5jAgO9ATIqx66oCTYXfGlNapGRLfSE
sa0iMqCe/dr4GPU+flW2DZFWiyJVDSF1JjReQnfrWY+SD2SpP/lmlgltnY8MJngd
xBLG5nsZCpUXGB713Q8ZyIm2ThVAMiskcxBleIZDDghLuhGvY/9eFJhZpvOkjWa6
XGEi4E1G/SA+zVKFl41nHKCdqXdmIOnpcLlFBUVloQok5a95Kqc1TYw3f+WbdFff
99dAgk3gWwWZQA==
-----END CERTIFICATE-----`)
var KeyPem = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAsraCori69OXEkA07Ykp2Ju+aHz33PqgKj0qbSbPm6ePh2mer
2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqqgM7uv9JoWCONcKo28cPRK7gH7up7
nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz5Bvi9fThcpf9M6qKdNUa2Wd24rM/
n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ZaEG86NXFpvvnLHRHiih/gXe7nby
1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtVSX4poNq9YrdP4XW9M0N7nocRU0p5
aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGTMwIDAQABAoIBAGh1m8hHWCg7gXh9
838RbRx3IswuKS27hWiaQEiFWmzOIb7KqDy1qAxtu+ayRY1paHegH6QY/+Kd824s
ibpzbgQacJ04/HrAVTVMmQ8Z2VLHoAN7lcPL1bd14aZGaLLZVtDeTDJ413grhxxv
4ho27gcgcbo4Z+rWgk7H2WRPCAGYqWYAycm3yF5vy9QaO6edU+T588YsEQOos5iy
5pVFSGDGZkcUp1ukL3BJYR+jvygn6WPCobQ/LScUdi+ucitaI9i+UdlLokZARVRG
M/msqcTM73thR8yVRcexU6NUDxRBfZ/f7moSAEbBmGDXuxDcIyH9KGMQ2rMtN1X3
lK8UNwkCgYEA2STJq/IUQHjdqd3Dqh/Q7Zm8/pMWFqLJSkqpnFtFsXPyUOx9zDOy
KqkIfGeyKwvsj9X9BcZ0FUKj9zoct1/WpPY+h7i7+z0MIujBh4AMjAcDrt4o76yK
UHuVmG2xKTdJoAbqOdToQeX6E82Ioal5pbB2W7AbCQScNBPZ52jxgtcCgYEA0rE7
2dFiRm0YmuszFYxft2+GP6NgP3R2TQNEooi1uCXG2xgwObie1YCHzpZ5CfSqJIxP
XB7DXpIWi7PxJoeai2F83LnmdFz6F1BPRobwDoSFNdaSKLg4Yf856zpgYNKhL1fE
OoOXj4VBWBZh1XDfZV44fgwlMIf7edOF1XOagwUCgYAw953O+7FbdKYwF0V3iOM5
oZDAK/UwN5eC/GFRVDfcM5RycVJRCVtlSWcTfuLr2C2Jpiz/72fgH34QU3eEVsV1
v94MBznFB1hESw7ReqvZq/9FoO3EVrl+OtBaZmosLD6bKtQJJJ0Xtz/01UW5hxla
pveZ55XBK9v51nwuNjk4UwKBgHD8fJUllSchUCWb5cwzeAz98Kdl7LJ6uQo5q2/i
EllLYOWThiEeIYdrIuklholRPIDXAaPsF2c6vn5yo+q+o6EFSZlw0+YpCjDAb5Lp
wAh5BprFk6HkkM/0t9Guf4rMyYWC8odSlE9x7YXYkuSMYDCTI4Zs6vCoq7I8PbQn
B4AlAoGAZ6Ee5m/ph5UVp/3+cR6jCY7aHBUU/M3pbJSkVjBW+ymEBVJ6sUdz8k3P
x8BiPEQggNN7faWBqRWP7KXPnDYHh6shYUgPJwI5HX6NE/ZDnnXjeysHRyf0oCo5
S6tHXwHNKB5HS1c/KDyyNGjP2oi/MF4o/MGWNWEcK6TJA3RGOYM=
-----END RSA PRIVATE KEY-----`)
var CaPem = []byte(`-----BEGIN CERTIFICATE-----
MIIDtTCCAp2gAwIBAgIJANeS1FOzWXlZMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTgwODE2MTUxNDE5WhcNMjEwNjA1MTUxNDE5WjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
CgKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QLsuT1fGf5
zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eoxrS64GSb
PVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKvlpEAtRAe
4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7IeXUc456
yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioONaePQKFR
/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABo4GnMIGkMB0GA1UdDgQWBBQgHiwD
00upIbCOunlK4HRw89DhjjB1BgNVHSMEbjBsgBQgHiwD00upIbCOunlK4HRw89Dh
jqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV
BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJANeS1FOzWXlZMAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAFMZFQTFKU5tWIpWh8BbVZeVZcng0Kiq
qwbhVwaTkqtfmbqw8/w+faOWylmLncQEMmgvnUltGMQlQKBwQM2byzPkz9phal3g
uI0JWJYqtcMyIQUB9QbbhrDNC9kdt/ji/x6rrIqzaMRuiBXqH5LQ9h856yXzArqd
cAQGzzYpbUCIv7ciSB93cKkU73fQLZVy5ZBy1+oAa1V9U4cb4G/20/PDmT+G3Gxz
pEjeDKtz8XINoWgA2cSdfAhNZt5vqJaCIZ8qN0z6C7SUKwUBderERUMLUXdhUldC
KTVHyEPvd0aULd5S5vEpKCnHcQmFcLdoN8t9k9pR9ZgwqXbyJHlxWFo=
-----END CERTIFICATE-----`)


@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.


@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@ -0,0 +1,509 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"math"
"reflect"
"strings"
"time"
)
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive struct {
undecoded interface{}
context Key
}
// Deprecated: Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case-insensitive match to struct field names will be tried if an exact
// match can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
}
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, indirect(rv))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
copy(context, md.context)
rv.Set(reflect.ValueOf(Primitive{
undecoded: data,
context: context,
}))
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
return md.unifyMap(data, rv)
case reflect.Array:
return md.unifyArray(data, rv)
case reflect.Slice:
return md.unifySlice(data, rv)
case reflect.String:
return md.unifyString(data, rv)
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("unsupported type %s", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
if mapping == nil {
return nil
}
return e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
}
for key, datum := range tmap {
var f *field
fields := cachedTypeFields(rv.Type())
for i := range fields {
ff := &fields[i]
if ff.name == key {
f = ff
break
}
if f == nil && strings.EqualFold(ff.name, key) {
f = ff
}
}
if f != nil {
subv := rv
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("cannot write unexported field %s.%s",
rv.Type().String(), f.name)
}
}
}
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
if tmap == nil {
return nil
}
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
}
n := datav.Len()
if rv.IsNil() || rv.Cap() < n {
rv.Set(reflect.MakeSlice(rv.Type(), n, n))
}
rv.SetLen(n)
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("value %d is out of range for int8", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("value %d is out of range for int16", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("value %d is out of range for int32", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("value %d is out of range for uint8", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("value %d is out of range for uint16", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("value %d is out of range for uint32", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
}
return nil
}
return badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
}
s = string(text)
case fmt.Stringer:
s = sdata.String()
case string:
s = sdata
case bool:
s = fmt.Sprintf("%v", sdata)
case int64:
s = fmt.Sprintf("%d", sdata)
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanSet() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
}
}
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("cannot load TOML value of type %T into a Go %s", data, expected)
}
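
For orientation, a minimal sketch of how the Decode API defined above is typically driven; the Config type, its field names, and the TOML snippet are invented for this illustration and are not part of the vendored package.

package main

import (
	"fmt"
	"log"

	"github.com/BurntSushi/toml"
)

// Config is a hypothetical target type; the `toml` struct tags map TOML
// keys to fields whose Go names differ from the keys, as documented above.
type Config struct {
	Title string   `toml:"title"`
	Port  int      `toml:"port"`
	Tags  []string `toml:"tags"`
}

func main() {
	const doc = `
title = "example"
port  = 3306
tags  = ["replication", "binlog"]
extra = true
`
	var cfg Config
	md, err := toml.Decode(doc, &cfg)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.Title, cfg.Port, cfg.Tags)

	// Keys present in the document but absent from Config (here "extra")
	// are reported by MetaData.Undecoded rather than causing an error.
	fmt.Println("undecoded:", md.Undecoded())
}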

View File

@ -0,0 +1,121 @@
package toml
import "strings"
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchically, e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key is given. Keys are case sensitive.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
}
return k[i]
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
// and so no keys will be considered decoded.
//
// In this sense, the Undecoded keys correspond to keys in the TOML document
// that do not have a concrete type in your representation.
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
undecoded = append(undecoded, key)
}
}
return undecoded
}
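
A short sketch of the MetaData queries defined above (IsDefined, Type, Keys); the TOML document and the expected outputs noted in the comments are illustrative assumptions, not package test data.

package main

import (
	"fmt"
	"log"

	"github.com/BurntSushi/toml"
)

func main() {
	const doc = `
[server]
host = "127.0.0.1"
port = 3306
`
	// Decode into a generic map; the MetaData is what we inspect here.
	var v map[string]interface{}
	md, err := toml.Decode(doc, &v)
	if err != nil {
		log.Fatal(err)
	}

	// Keys are addressed hierarchically, mirroring the TOML structure.
	fmt.Println(md.IsDefined("server", "port")) // true
	fmt.Println(md.Type("server", "port"))      // Integer

	// Keys() preserves document order: server, server.host, server.port.
	for _, k := range md.Keys() {
		fmt.Println(k.String())
	}
}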

View File

@ -0,0 +1,27 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/toml-lang/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml

View File

@ -0,0 +1,568 @@
package toml
import (
"bufio"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"toml: cannot encode array with mixed element types")
errArrayNilElement = errors.New(
"toml: cannot encode array with nil element")
errNonString = errors.New(
"toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New(
"toml: cannot encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"toml: TOML array element cannot contain a table")
errNoKey = errors.New(
"toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
if terr, ok := r.(tomlEncodeError); ok {
err = terr.error
return
}
panic(r)
}
}()
enc.encode(key, rv)
return nil
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we use that.
// Basically, this prevents the encoder from handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
return
}
k := rv.Kind()
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
}
case reflect.Interface:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Map:
if rv.IsNil() {
return
}
enc.eTable(key, rv)
case reflect.Ptr:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("unexpected primitive type: %s", rv.Kind()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
}
return fstr
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
}
}
enc.wf("]")
}
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra newline between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
case reflect.Struct:
enc.eStruct(key, rv)
default:
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
}
// Sort keys so that we have deterministic output. And write keys directly
// underneath this key first, before writing sub-structs or sub-maps.
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
continue
}
enc.encode(key.add(mapKey), mrv)
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexported fields
if f.PkgPath != "" && !f.Anonymous {
continue
}
frv := rv.Field(i)
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
// Treat anonymous struct fields with
// tag names as though they are not
// anonymous, like encoding/json does.
if getOptions(f.Tag).name == "" {
addFields(t, frv, f.Index)
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct &&
getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), f.Index)
}
continue
}
// Fall through to the normal field encoding logic below
// for non-struct anonymous fields.
}
}
if typeIsHash(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
continue
}
opts := getOptions(sft.Tag)
if opts.skip {
continue
}
keyName := sft.Name
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && isEmpty(sf) {
continue
}
if opts.omitzero && isZero(sf) {
continue
}
enc.encode(key.add(keyName), sf)
}
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
}
// tomlTypeOfGo returns the TOML type of a Go value. The type may be `nil`,
// which means no concrete TOML type could be found. Among other things, it
// is used to determine whether the types of array elements are mixed (which
// is forbidden); a nil Go value yields a nil TOML type, which is illegal for
// an array element.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInteger
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
}
return tomlArray
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
}
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero-length
// slice). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
type tagOptions struct {
skip bool // "-"
name string
omitempty bool
omitzero bool
}
func getOptions(tag reflect.StructTag) tagOptions {
t := tag.Get("toml")
if t == "-" {
return tagOptions{skip: true}
}
var opts tagOptions
parts := strings.Split(t, ",")
opts.name = parts[0]
for _, s := range parts[1:] {
switch s {
case "omitempty":
opts.omitempty = true
case "omitzero":
opts.omitzero = true
}
}
return opts
}
func isZero(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return rv.Uint() == 0
case reflect.Float32, reflect.Float64:
return rv.Float() == 0.0
}
return false
}
func isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return rv.Len() == 0
case reflect.Bool:
return !rv.Bool()
}
return false
}
func (enc *Encoder) newline() {
if enc.hasWritten {
enc.wf("\n")
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
encPanic(err)
}
enc.hasWritten = true
}
func (enc *Encoder) indentStr(key Key) string {
return strings.Repeat(enc.Indent, len(key)-1)
}
func encPanic(err error) {
panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
return v
}
}
func isNil(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}
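
A matching sketch for the Encoder defined above; the Database type and its values are invented, and the TOML shown in the comment is what the omitempty handling above should produce for them.

package main

import (
	"log"
	"os"

	"github.com/BurntSushi/toml"
)

// Database is a hypothetical source type for encoding.
type Database struct {
	Host    string   `toml:"host"`
	Port    int      `toml:"port"`
	Replica bool     `toml:"replica,omitempty"`
	Tags    []string `toml:"tags,omitempty"`
}

func main() {
	db := Database{Host: "127.0.0.1", Port: 3306}

	// Fields tagged omitempty that are currently empty (Replica, Tags)
	// are skipped, so only host and port should be written:
	//
	//   host = "127.0.0.1"
	//   port = 3306
	if err := toml.NewEncoder(os.Stdout).Encode(db); err != nil {
		log.Fatal(err)
	}
}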

View File

@ -0,0 +1,19 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

View File

@ -0,0 +1,18 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}
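
The TextMarshaler/TextUnmarshaler aliases above are what let user-defined types supply their own textual form to the decoder and encoder. Below is a hypothetical duration wrapper, invented for illustration; the Decode documentation earlier alludes to exactly this kind of time-duration use.

package main

import (
	"fmt"
	"log"
	"time"

	"github.com/BurntSushi/toml"
)

// duration is a hypothetical wrapper satisfying TextUnmarshaler and
// TextMarshaler, so TOML strings like "1m30s" map onto time.Duration.
type duration struct {
	time.Duration
}

func (d *duration) UnmarshalText(text []byte) error {
	var err error
	d.Duration, err = time.ParseDuration(string(text))
	return err
}

func (d duration) MarshalText() ([]byte, error) {
	return []byte(d.Duration.String()), nil
}

func main() {
	var cfg struct {
		Timeout duration `toml:"timeout"`
	}
	if _, err := toml.Decode(`timeout = "1m30s"`, &cfg); err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.Timeout.Duration) // 1m30s
}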

View File

@ -0,0 +1,953 @@
package toml
import (
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemRawString
itemMultilineString
itemRawMultilineString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
itemInlineTableStart
itemInlineTableEnd
)
const (
eof = 0
comma = ','
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
inlineTableStart = '{'
inlineTableEnd = '}'
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
line int
state stateFn
items chan item
// Allow for backing up up to three runes.
// This is necessary because TOML contains 3-rune tokens (""" and ''').
prevWidths [3]int
nprev int // how many of prevWidths are in use
// If we emit an eof, we can still back up, but it is not OK to call
// next again.
atEOF bool
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
}
func lex(input string) *lexer {
lx := &lexer{
input: input,
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) emitTrim(typ itemType) {
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.atEOF {
panic("next called after EOF")
}
if lx.pos >= len(lx.input) {
lx.atEOF = true
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
lx.prevWidths[2] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[0]
if lx.nprev < 3 {
lx.nprev++
}
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.prevWidths[0] = w
lx.pos += w
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only twice between calls to next.
func (lx *lexer) backup() {
if lx.atEOF {
lx.atEOF = false
return
}
if lx.nprev < 1 {
panic("backed up too far")
}
w := lx.prevWidths[0]
lx.prevWidths[0] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[2]
lx.nprev--
lx.pos -= w
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// skip ignores all input that matches the given predicate.
func (lx *lexer) skip(pred func(rune) bool) {
for {
r := lx.next()
if pred(r) {
continue
}
lx.backup()
lx.ignore()
return
}
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (newlines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("unexpected EOF")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a newline. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a newline for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.emit(itemEOF)
return nil
}
return lx.errorf("expected a top-level item to end with a newline, "+
"comment, or EOF, but got %q instead", r)
}
// lexTableStart lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("expected end of table array name delimiter %q, "+
"but got %q instead", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("unexpected end of table name " +
"(table names cannot be empty)")
case r == tableSep:
return lx.errorf("unexpected table separator " +
"(table names cannot be empty)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
return lexValue // reuse string lexing
default:
return lexBareTableName
}
}
// lexBareTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
r := lx.next()
if isBareKeyChar(r) {
return lexBareTableName
}
lx.backup()
lx.emit(itemText)
return lexTableNameEnd
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
case r == tableSep:
lx.ignore()
return lexTableNameStart
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("expected '.' or ']' to end table name, "+
"but got %q instead", r)
}
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("unexpected key separator %q", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.emit(itemKeyStart)
lx.push(lexKeyEnd)
return lexValue // reuse string lexing
default:
lx.ignore()
lx.emit(itemKeyStart)
return lexBareKey
}
}
// lexBareKey consumes the text of a bare key. Assumes that the first character
// (which is not whitespace) has not yet been consumed.
func lexBareKey(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.backup()
lx.emit(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emit(itemText)
return lexKeyEnd
default:
return lx.errorf("bare keys cannot contain %q", r)
}
}
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
// separator).
func lexKeyEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case r == keySep:
return lexSkip(lx, lexValue)
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("expected key separator %q, but got %q instead",
keySep, r)
}
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the stack is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT newlines.
// In array syntax, the array states are responsible for ignoring newlines.
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexValue)
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
}
switch r {
case arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case inlineTableStart:
lx.ignore()
lx.emit(itemInlineTableStart)
return lexInlineTableValue
case stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
return lexMultilineString
}
lx.backup()
}
lx.ignore() // ignore the '"'
return lexString
case rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore '''
return lexMultilineRawString
}
lx.backup()
}
lx.ignore() // ignore the "'"
return lexRawString
case '+', '-':
return lexNumberStart
case '.': // special error case, be kind to users
return lx.errorf("floats must start with a digit, not '.'")
}
if unicode.IsLetter(r) {
// Be permissive here; lexBool will give a nice error if the
// user wrote something like
// x = foo
// (i.e. not 'true' or 'false' but is something else word-like.)
lx.backup()
return lexBool
}
return lx.errorf("expected value but found %q instead", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and newlines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == arrayEnd:
// NOTE(caleb): The spec isn't clear about whether you can have
// a trailing comma or not, so we'll allow it.
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes everything between the end of an array value and
// the next value (or the end of the array): it ignores whitespace and newlines
// and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf(
"expected a comma or array terminator %q, but got %q instead",
arrayEnd, r,
)
}
// lexArrayEnd finishes the lexing of an array.
// It assumes that a ']' has just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexInlineTableValue consumes one key/value pair in an inline table.
// It assumes that '{' or ',' have already been consumed. Whitespace is ignored.
func lexInlineTableValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValue)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == inlineTableEnd:
return lexInlineTableEnd
}
lx.backup()
lx.push(lexInlineTableValueEnd)
return lexKeyStart
}
// lexInlineTableValueEnd consumes everything between the end of an inline table
// key/value pair and the next pair (or the end of the table):
// it ignores whitespace and expects either a ',' or a '}'.
func lexInlineTableValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValueEnd)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexInlineTableValue
case r == inlineTableEnd:
return lexInlineTableEnd
}
return lx.errorf("expected a comma or an inline table terminator %q, "+
"but got %q instead", inlineTableEnd, r)
}
// lexInlineTableEnd finishes the lexing of an inline table.
// It assumes that a '}' has just been consumed.
func lexInlineTableEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemInlineTableEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("strings cannot contain newlines")
case r == '\\':
lx.push(lexString)
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case '\\':
return lexMultilineStringEscape
case stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineString
}
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
// It assumes that the beginning "'" has already been consumed and ignored.
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("strings cannot contain newlines")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexRawString
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'''" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemRawMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineRawString
}
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
return lexMultilineString
}
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '\\':
return lx.pop()
case 'u':
return lexShortUnicodeEscape
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("invalid escape character %q; only the following "+
"escape characters are allowed: "+
`\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(`expected four hexadecimal digits after '\u', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
}
func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf(`expected eight hexadecimal digits after '\U', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either an integer, a float, or datetime.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumberOrDate
}
switch r {
case '_':
return lexNumber
case 'e', 'E':
return lexFloat
case '.':
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
// lexNumberOrDate consumes either an integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumberOrDate
}
switch r {
case '-':
return lexDatetime
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDatetime consumes a Datetime, to a first approximation.
// The parser validates that it matches one of the accepted formats.
func lexDatetime(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexDatetime
}
switch r {
case '-', 'T', ':', '.', 'Z':
return lexDatetime
}
lx.backup()
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that a sign
// has already been read, but that *no* digits have been consumed.
// lexNumberStart will move to the appropriate integer or float states.
func lexNumberStart(lx *lexer) stateFn {
// We MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexNumber
}
switch r {
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloat consumes the elements of a float. It allows any sequence of
// float-like characters, so floats emitted by the lexer are only a first
// approximation and must be validated by the parser.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
switch r {
case '_', '.', '-', '+', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexBool consumes a bool string: 'true' or 'false'.
func lexBool(lx *lexer) stateFn {
var rs []rune
for {
r := lx.next()
if !unicode.IsLetter(r) {
lx.backup()
break
}
rs = append(rs, r)
}
s := string(rs)
switch s {
case "true", "false":
lx.emit(itemBool)
return lx.pop()
}
return lx.errorf("expected value but found %q instead", s)
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first newline character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' ||
r == '-'
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}
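
The lexer above follows the familiar stateFn pattern: each state function consumes some input and returns the next state, with a stack for nested contexts and a channel for emitted items. The following is a deliberately tiny, self-contained imitation of that pattern with invented names and a trivial grammar; it is not this package's lexer, and it drops the channel and stack for brevity.

package main

import "fmt"

// miniLexer is a toy lexer used only to illustrate the stateFn idea:
// lexing proceeds by repeatedly calling the current state until it is nil.
type miniLexer struct {
	input string
	pos   int
	items []string
}

type miniState func(*miniLexer) miniState

// lexWord consumes a run of non-space characters and emits it as one item.
func lexWord(lx *miniLexer) miniState {
	start := lx.pos
	for lx.pos < len(lx.input) && lx.input[lx.pos] != ' ' {
		lx.pos++
	}
	lx.items = append(lx.items, lx.input[start:lx.pos])
	return lexSpace
}

// lexSpace skips spaces and decides whether to continue or stop at EOF.
func lexSpace(lx *miniLexer) miniState {
	for lx.pos < len(lx.input) && lx.input[lx.pos] == ' ' {
		lx.pos++
	}
	if lx.pos >= len(lx.input) {
		return nil // EOF ends the state machine
	}
	return lexWord
}

func main() {
	lx := &miniLexer{input: "key = value"}
	for state := miniState(lexWord); state != nil; {
		state = state(lx)
	}
	fmt.Println(lx.items) // [key = value]
}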

View File

@ -0,0 +1,592 @@
package toml
import (
"fmt"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panicf("%s", it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// Gets a string for a key (or part of a key in a table name).
func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
if !numUnderscoresOK(it.val) {
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseInt(val, 10, 64)
if err != nil {
// Distinguish integer values. Normally, it'd be a bug if the lexer
// provides an invalid integer, but it's possible that the number is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
return true
}
return false
})
for _, part := range parts {
if !numUnderscoresOK(part) {
p.panicf("Invalid float %q: underscores must be "+
"surrounded by digits", it.val)
}
}
if !numPeriodsOK(it.val) {
// As a special case, numbers like '123.' or '1.e2',
// which are valid as far as Go/strconv are concerned,
// must be rejected because TOML says that a fractional
// part consists of '.' followed by 1+ digits.
p.panicf("Invalid float %q: '.' must be followed "+
"by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseFloat(val, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.panicf("Invalid float value: %q", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
var t time.Time
var ok bool
var err error
for _, format := range []string{
"2006-01-02T15:04:05Z07:00",
"2006-01-02T15:04:05",
"2006-01-02",
} {
t, err = time.ParseInLocation(format, it.val, time.Local)
if err == nil {
ok = true
break
}
}
if !ok {
p.panicf("Invalid TOML Datetime: %q.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
case itemInlineTableStart:
var (
hash = make(map[string]interface{})
outerContext = p.context
outerKey = p.currentKey
)
p.context = append(p.context, p.currentKey)
p.currentKey = ""
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
if it.typ != itemKeyStart {
p.bug("Expected key start but instead found %q, around line %d",
it.val, p.approxLine)
}
// retrieve key
k := p.next()
p.approxLine = k.line
kname := p.keyString(k)
// retrieve value
p.currentKey = kname
val, typ := p.value(p.next())
// make sure we keep metadata up to date
p.setType(kname, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[kname] = val
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// numUnderscoresOK checks whether each underscore in s is surrounded by
// characters that are not underscores.
func numUnderscoresOK(s string) bool {
accept := false
for _, r := range s {
if r == '_' {
if !accept {
return false
}
accept = false
continue
}
accept = true
}
return accept
}
// numPeriodsOK checks whether every period in s is followed by a digit.
func numPeriodsOK(s string) bool {
period := false
for _, r := range s {
if period && !isDigit(r) {
return false
}
period = r == '.'
}
return !period
}
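// Illustrative sketch, not part of the vendored file: assuming the same
// package, this hypothetical helper shows what the two validators above
// accept and reject for a few numeric literals.
func exampleNumberValidation() {
	fmt.Println(numUnderscoresOK("1_000_000")) // true: every '_' sits between digits
	fmt.Println(numUnderscoresOK("_1000"))     // false: leading underscore
	fmt.Println(numPeriodsOK("3.14"))          // true: '.' is followed by a digit
	fmt.Println(numPeriodsOK("1."))            // false: trailing '.' with no digits
}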
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, accounting for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
return s
}
return s[1:]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
}
}
return strings.Join(esc, "")
}
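// Illustrative sketch, not part of the vendored file: the two helpers above
// implement the multiline-string rules, dropping a leading newline and
// collapsing a backslash-newline (plus the following whitespace) to nothing
// (hypothetical helper; same package assumed).
func exampleMultilineHelpers() {
	fmt.Println(stripFirstNewline("\nhello"))             // hello
	fmt.Println(stripEscapedWhitespace("one \\\n   two")) // one two
}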
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
continue
}
r += 1
if r >= len(s) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `UXXXXXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
replaced = append(replaced, escaped)
r += 9
}
}
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
if !utf8.ValidRune(rune(hex)) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}
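// Illustrative sketch, not part of the vendored file: assuming the same
// package, this hypothetical helper drives parse() directly and inspects the
// mapping built by the code above (the exported decode entry points build on it).
func exampleParse() {
	p, err := parse("title = \"TOML Example\"\n[owner]\nname = \"Tom\"\n")
	if err != nil {
		panic(err)
	}
	fmt.Println(p.mapping["title"]) // TOML Example
	owner := p.mapping["owner"].(map[string]interface{})
	fmt.Println(owner["name"]) // Tom
}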


@ -0,0 +1,91 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be moving
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
if t1 == nil || t2 == nil {
return false
}
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemMultilineString:
return tomlString
case itemRawString:
return tomlString
case itemRawMultilineString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}
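// Illustrative sketch, not part of the vendored file: homogeneous arrays
// collapse to the single "Array" type, while a mixed array goes through
// p.panicf above (hypothetical helper; fmt assumed imported for the demo).
func exampleArrayTypes(p *parser) {
	fmt.Println(p.typeOfArray(nil))                                  // Array (empty arrays are fine)
	fmt.Println(p.typeOfArray([]tomlType{tomlInteger, tomlInteger})) // Array
	// p.typeOfArray([]tomlType{tomlInteger, tomlString}) would panic via p.panicf.
}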


@ -0,0 +1,242 @@
package toml
// Struct field handling is adapted from code in encoding/json:
//
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the Go distribution.
import (
"reflect"
"sort"
"sync"
)
// A field represents a single field found in a struct.
type field struct {
name string // the name of the field (`toml` tag included)
tag bool // whether field has a `toml` tag
index []int // represents the depth of an anonymous field
typ reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth,
// then breaking ties with "name came from toml tag", then
// breaking ties with index sequence.
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.PkgPath != "" && !sf.Anonymous { // unexported
continue
}
opts := getOptions(sf.Tag)
if opts.skip {
continue
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
// Follow pointer.
ft = ft.Elem()
}
// Record found field and index sequence.
if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := opts.name != ""
name := opts.name
if name == "" {
name = sf.Name
}
fields = append(fields, field{name, tagged, index, ft})
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
f := field{name: ft.Name(), index: index, typ: ft}
next = append(next, f)
}
}
}
}
sort.Sort(byName(fields))
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with TOML tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
return fields
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// TOML tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order. The winner
// must therefore be one with the shortest index length. Drop all
// longer entries, which is easy: just truncate the slice.
length := len(fields[0].index)
tagged := -1 // Index of first tagged field.
for i, f := range fields {
if len(f.index) > length {
fields = fields[:i]
break
}
if f.tag {
if tagged >= 0 {
// Multiple tagged fields at the same level: conflict.
// Return no field.
return field{}, false
}
tagged = i
}
}
if tagged >= 0 {
return fields[tagged], true
}
// All remaining fields have the same length. If there's more than one,
// we have a conflict (two fields named "X" at the same level) and we
// return no field.
if len(fields) > 1 {
return field{}, false
}
return fields[0], true
}
var fieldCache struct {
sync.RWMutex
m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) []field {
fieldCache.RLock()
f := fieldCache.m[t]
fieldCache.RUnlock()
if f != nil {
return f
}
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = typeFields(t)
if f == nil {
f = []field{}
}
fieldCache.Lock()
if fieldCache.m == nil {
fieldCache.m = map[reflect.Type][]field{}
}
fieldCache.m[t] = f
fieldCache.Unlock()
return f
}
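// Illustrative sketch, not part of the vendored file: assuming the same
// package (fmt imported for the demo), typeFields reports the TOML-visible
// fields of a struct; a `toml` tag overrides the Go name and unexported
// fields are skipped. The struct and helper below are hypothetical.
type exampleConfig struct {
	Name    string `toml:"name"`
	Port    int
	private string // unexported: ignored by typeFields
}

func exampleTypeFields() {
	for _, f := range typeFields(reflect.TypeOf(exampleConfig{})) {
		fmt.Println(f.name, f.index) // "name [0]" then "Port [1]"
	}
}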


@ -0,0 +1,90 @@
# This is the official list of Go-MySQL-Driver authors for copyright purposes.
# If you are submitting a patch, please add your name or the name of the
# organization which holds the copyright to this list in alphabetical order.
# Names should be added to this file as
# Name <email address>
# The email address is not required for organizations.
# Please keep the list sorted.
# Individual Persons
Aaron Hopkins <go-sql-driver at die.net>
Achille Roussel <achille.roussel at gmail.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com>
Arne Hormann <arnehormann at gmail.com>
Asta Xie <xiemengjun at gmail.com>
Bulat Gaifullin <gaifullinbf at gmail.com>
Carlos Nieto <jose.carlos at menteslibres.net>
Chris Moos <chris at tech9computers.com>
Craig Wilson <craiggwilson at gmail.com>
Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
ICHINOSE Shogo <shogo82148 at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lennart Rudolph <lrudolph at hmc.edu>
Leonardo YongUk Kim <dalinaum at gmail.com>
Linh Tran Tuan <linhduonggnu at gmail.com>
Lion Yang <lion at aosc.xyz>
Luca Looz <luca.looz92 at gmail.com>
Lucas Liu <extrafliu at gmail.com>
Luke Scott <luke at webconnex.com>
Maciej Zimnoch <maciej.zimnoch at codilime.com>
Michael Woolnough <michael.woolnough at gmail.com>
Nicola Peduzzi <thenikso at gmail.com>
Olivier Mengué <dolmen at cpan.org>
oscarzhao <oscarzhaosl at gmail.com>
Paul Bonser <misterpib at gmail.com>
Peter Schultz <peter.schultz at classmarkets.com>
Rebecca Chin <rchin at pivotal.io>
Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com>
Robert Russell <robert at rrbrussell.com>
Runrioter Wung <runrioter at gmail.com>
Shuode Li <elemount at qq.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
Zhenye Xie <xiezhenye at gmail.com>
# Organizations
Barracuda Networks, Inc.
Counting Ltd.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Percona LLC
Pivotal Inc.
Stripe Inc.


@ -0,0 +1,373 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.


@ -0,0 +1,19 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build appengine
package mysql
import (
"google.golang.org/appengine/cloudsql"
)
func init() {
RegisterDial("cloudsql", cloudsql.Dial)
}


@ -0,0 +1,420 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"sync"
)
// server pub keys registry
var (
serverPubKeyLock sync.RWMutex
serverPubKeyRegistry map[string]*rsa.PublicKey
)
// RegisterServerPubKey registers a server RSA public key which can be used to
// send data in a secure manner to the server without receiving the public key
// in a potentially insecure way from the server first.
// Registered keys can afterwards be used by adding serverPubKey=<name> to the DSN.
//
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
// after registering it and may not be modified.
//
// data, err := ioutil.ReadFile("mykey.pem")
// if err != nil {
// log.Fatal(err)
// }
//
// block, _ := pem.Decode(data)
// if block == nil || block.Type != "PUBLIC KEY" {
// log.Fatal("failed to decode PEM block containing public key")
// }
//
// pub, err := x509.ParsePKIXPublicKey(block.Bytes)
// if err != nil {
// log.Fatal(err)
// }
//
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
// mysql.RegisterServerPubKey("mykey", rsaPubKey)
// } else {
// log.Fatal("not a RSA public key")
// }
//
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
serverPubKeyLock.Lock()
if serverPubKeyRegistry == nil {
serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
}
serverPubKeyRegistry[name] = pubKey
serverPubKeyLock.Unlock()
}
// DeregisterServerPubKey removes the public key registered with the given name.
func DeregisterServerPubKey(name string) {
serverPubKeyLock.Lock()
if serverPubKeyRegistry != nil {
delete(serverPubKeyRegistry, name)
}
serverPubKeyLock.Unlock()
}
func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
serverPubKeyLock.RLock()
if v, ok := serverPubKeyRegistry[name]; ok {
pubKey = v
}
serverPubKeyLock.RUnlock()
return
}
// Hash password using pre 4.1 (old password) method
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
seed1, seed2 uint32
}
const myRndMaxVal = 0x3FFFFFFF
// Pseudo random number generator
func newMyRnd(seed1, seed2 uint32) *myRnd {
return &myRnd{
seed1: seed1 % myRndMaxVal,
seed2: seed2 % myRndMaxVal,
}
}
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
}
// Generate binary hash from byte string using insecure pre 4.1 method
func pwHash(password []byte) (result [2]uint32) {
var add uint32 = 7
var tmp uint32
result[0] = 1345345333
result[1] = 0x12345671
for _, c := range password {
// skip spaces and tabs in password
if c == ' ' || c == '\t' {
continue
}
tmp = uint32(c)
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
result[1] += (result[1] << 8) ^ result[0]
add += tmp
}
// Remove sign bit: mask with (1<<31)-1
result[0] &= 0x7FFFFFFF
result[1] &= 0x7FFFFFFF
return
}
// Hash password using insecure pre 4.1 method
func scrambleOldPassword(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
scramble = scramble[:8]
hashPw := pwHash([]byte(password))
hashSc := pwHash(scramble)
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
var out [8]byte
for i := range out {
out[i] = r.NextByte() + 64
}
mask := r.NextByte()
for i := range out {
out[i] ^= mask
}
return out[:]
}
// Hash password using 4.1+ method (SHA1)
func scramblePassword(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
// stage1Hash = SHA1(password)
crypt := sha1.New()
crypt.Write([]byte(password))
stage1 := crypt.Sum(nil)
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
// inner Hash
crypt.Reset()
crypt.Write(stage1)
hash := crypt.Sum(nil)
// outer Hash
crypt.Reset()
crypt.Write(scramble)
crypt.Write(hash)
scramble = crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
// Hash password using MySQL 8+ method (SHA256)
func scrambleSHA256Password(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
crypt := sha256.New()
crypt.Write([]byte(password))
message1 := crypt.Sum(nil)
crypt.Reset()
crypt.Write(message1)
message1Hash := crypt.Sum(nil)
crypt.Reset()
crypt.Write(message1Hash)
crypt.Write(scramble)
message2 := crypt.Sum(nil)
for i := range message1 {
message1[i] ^= message2[i]
}
return message1
}
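// Illustrative sketch, not part of the vendored file: assuming the same
// package (fmt imported for the demo), both scramble helpers take the
// server-sent challenge plus the cleartext password and return the token the
// client answers with. The challenge bytes below are made up.
func exampleScramble() {
	challenge := []byte("0123456789abcdefghij") // 20 bytes, normally from the handshake packet
	fmt.Printf("%x\n", scramblePassword(challenge, "secret"))       // 20-byte SHA1-based token
	fmt.Printf("%x\n", scrambleSHA256Password(challenge, "secret")) // 32-byte SHA256-based token
}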
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
plain := make([]byte, len(password)+1)
copy(plain, password)
for i := range plain {
j := i % len(seed)
plain[i] ^= seed[j]
}
sha1 := sha1.New()
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
}
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
if err != nil {
return err
}
return mc.writeAuthSwitchPacket(enc, false)
}
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
switch plugin {
case "caching_sha2_password":
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
return authResp, false, nil
case "mysql_old_password":
if !mc.cfg.AllowOldPasswords {
return nil, false, ErrOldPassword
}
// Note: there are edge cases where this should work but doesn't;
// this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
return authResp, true, nil
case "mysql_clear_password":
if !mc.cfg.AllowCleartextPasswords {
return nil, false, ErrCleartextPassword
}
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
return []byte(mc.cfg.Passwd), true, nil
case "mysql_native_password":
if !mc.cfg.AllowNativePasswords {
return nil, false, ErrNativePassword
}
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
// Native password authentication only needs, and will only ever need, a 20-byte challenge.
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
return authResp, false, nil
case "sha256_password":
if len(mc.cfg.Passwd) == 0 {
return nil, true, nil
}
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
return []byte(mc.cfg.Passwd), true, nil
}
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
return []byte{1}, false, nil
}
// encrypted password
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, false, err
default:
errLog.Print("unknown auth plugin:", plugin)
return nil, false, ErrUnknownPlugin
}
}
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
// Read Result Packet
authData, newPlugin, err := mc.readAuthResult()
if err != nil {
return err
}
// handle auth plugin switch, if requested
if newPlugin != "" {
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
// sent and we have to keep using the cipher sent in the init packet.
if authData == nil {
authData = oldAuthData
} else {
// copy data from read buffer to owned slice
copy(oldAuthData, authData)
}
plugin = newPlugin
authResp, addNUL, err := mc.auth(authData, plugin)
if err != nil {
return err
}
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
return err
}
// Read Result Packet
authData, newPlugin, err = mc.readAuthResult()
if err != nil {
return err
}
// Do not allow the auth plugin to be changed more than once
if newPlugin != "" {
return ErrMalformPkt
}
}
switch plugin {
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
case "caching_sha2_password":
switch len(authData) {
case 0:
return nil // auth successful
case 1:
switch authData[0] {
case cachingSha2PasswordFastAuthSuccess:
if err = mc.readResultOK(); err == nil {
return nil // auth successful
}
case cachingSha2PasswordPerformFullAuthentication:
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
if err != nil {
return err
}
} else {
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
data := mc.buf.takeSmallBuffer(4 + 1)
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)
// parse public key
data, err := mc.readPacket()
if err != nil {
return err
}
block, _ := pem.Decode(data[1:])
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
}
pubKey = pkix.(*rsa.PublicKey)
}
// send encrypted password
err = mc.sendEncryptedPassword(oldAuthData, pubKey)
if err != nil {
return err
}
}
return mc.readResultOK()
default:
return ErrMalformPkt
}
default:
return ErrMalformPkt
}
case "sha256_password":
switch len(authData) {
case 0:
return nil // auth successful
default:
block, _ := pem.Decode(authData)
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
}
// send encrypted password
err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey))
if err != nil {
return err
}
return mc.readResultOK()
}
default:
return nil // auth successful
}
return err
}


@ -0,0 +1,147 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"io"
"net"
"time"
)
const defaultBufSize = 4096
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish.
// It is also highly optimized for this particular use case.
type buffer struct {
buf []byte
nc net.Conn
idx int
length int
timeout time.Duration
}
func newBuffer(nc net.Conn) buffer {
var b [defaultBufSize]byte
return buffer{
buf: b[:],
nc: nc,
}
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
n := b.length
// move existing data to the beginning
if n > 0 && b.idx > 0 {
copy(b.buf[0:n], b.buf[b.idx:])
}
// grow buffer if necessary
// TODO: let the buffer shrink again at some point
// Maybe keep the org buf slice and swap back?
if need > len(b.buf) {
// Round up to the next multiple of the default size
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
copy(newBuf, b.buf)
b.buf = newBuf
}
b.idx = 0
for {
if b.timeout > 0 {
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
nn, err := b.nc.Read(b.buf[n:])
n += nn
switch err {
case nil:
if n < need {
continue
}
b.length = n
return nil
case io.EOF:
if n >= need {
b.length = n
return nil
}
return io.ErrUnexpectedEOF
default:
return err
}
}
}
// readNext returns the next N bytes from the buffer.
// The returned slice is only guaranteed to be valid until the next read.
func (b *buffer) readNext(need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
}
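// Illustrative sketch, not part of the vendored file: assuming the same
// package, the read path is to wrap a net.Conn and then request exactly the
// bytes needed; the returned slice aliases the internal buffer and is only
// valid until the next read (hypothetical helper).
func exampleBufferRead(nc net.Conn) error {
	b := newBuffer(nc)
	header, err := b.readNext(4) // e.g. a 4-byte packet header
	if err != nil {
		return err
	}
	_ = header // use before the next readNext call invalidates it
	return nil
}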
// returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
if b.length > 0 {
return nil
}
// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
}
return make([]byte, length)
}
// takeSmallBuffer is a shortcut which can be used if the requested buffer is
// guaranteed to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length > 0 {
return nil
}
return b.buf[:length]
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
if b.length > 0 {
return nil
}
return b.buf
}


@ -0,0 +1,251 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const defaultCollation = "utf8_general_ci"
const binaryCollation = "binary"
// A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query:
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS
var collations = map[string]byte{
"big5_chinese_ci": 1,
"latin2_czech_cs": 2,
"dec8_swedish_ci": 3,
"cp850_general_ci": 4,
"latin1_german1_ci": 5,
"hp8_english_ci": 6,
"koi8r_general_ci": 7,
"latin1_swedish_ci": 8,
"latin2_general_ci": 9,
"swe7_swedish_ci": 10,
"ascii_general_ci": 11,
"ujis_japanese_ci": 12,
"sjis_japanese_ci": 13,
"cp1251_bulgarian_ci": 14,
"latin1_danish_ci": 15,
"hebrew_general_ci": 16,
"tis620_thai_ci": 18,
"euckr_korean_ci": 19,
"latin7_estonian_cs": 20,
"latin2_hungarian_ci": 21,
"koi8u_general_ci": 22,
"cp1251_ukrainian_ci": 23,
"gb2312_chinese_ci": 24,
"greek_general_ci": 25,
"cp1250_general_ci": 26,
"latin2_croatian_ci": 27,
"gbk_chinese_ci": 28,
"cp1257_lithuanian_ci": 29,
"latin5_turkish_ci": 30,
"latin1_german2_ci": 31,
"armscii8_general_ci": 32,
"utf8_general_ci": 33,
"cp1250_czech_cs": 34,
"ucs2_general_ci": 35,
"cp866_general_ci": 36,
"keybcs2_general_ci": 37,
"macce_general_ci": 38,
"macroman_general_ci": 39,
"cp852_general_ci": 40,
"latin7_general_ci": 41,
"latin7_general_cs": 42,
"macce_bin": 43,
"cp1250_croatian_ci": 44,
"utf8mb4_general_ci": 45,
"utf8mb4_bin": 46,
"latin1_bin": 47,
"latin1_general_ci": 48,
"latin1_general_cs": 49,
"cp1251_bin": 50,
"cp1251_general_ci": 51,
"cp1251_general_cs": 52,
"macroman_bin": 53,
"utf16_general_ci": 54,
"utf16_bin": 55,
"utf16le_general_ci": 56,
"cp1256_general_ci": 57,
"cp1257_bin": 58,
"cp1257_general_ci": 59,
"utf32_general_ci": 60,
"utf32_bin": 61,
"utf16le_bin": 62,
"binary": 63,
"armscii8_bin": 64,
"ascii_bin": 65,
"cp1250_bin": 66,
"cp1256_bin": 67,
"cp866_bin": 68,
"dec8_bin": 69,
"greek_bin": 70,
"hebrew_bin": 71,
"hp8_bin": 72,
"keybcs2_bin": 73,
"koi8r_bin": 74,
"koi8u_bin": 75,
"latin2_bin": 77,
"latin5_bin": 78,
"latin7_bin": 79,
"cp850_bin": 80,
"cp852_bin": 81,
"swe7_bin": 82,
"utf8_bin": 83,
"big5_bin": 84,
"euckr_bin": 85,
"gb2312_bin": 86,
"gbk_bin": 87,
"sjis_bin": 88,
"tis620_bin": 89,
"ucs2_bin": 90,
"ujis_bin": 91,
"geostd8_general_ci": 92,
"geostd8_bin": 93,
"latin1_spanish_ci": 94,
"cp932_japanese_ci": 95,
"cp932_bin": 96,
"eucjpms_japanese_ci": 97,
"eucjpms_bin": 98,
"cp1250_polish_ci": 99,
"utf16_unicode_ci": 101,
"utf16_icelandic_ci": 102,
"utf16_latvian_ci": 103,
"utf16_romanian_ci": 104,
"utf16_slovenian_ci": 105,
"utf16_polish_ci": 106,
"utf16_estonian_ci": 107,
"utf16_spanish_ci": 108,
"utf16_swedish_ci": 109,
"utf16_turkish_ci": 110,
"utf16_czech_ci": 111,
"utf16_danish_ci": 112,
"utf16_lithuanian_ci": 113,
"utf16_slovak_ci": 114,
"utf16_spanish2_ci": 115,
"utf16_roman_ci": 116,
"utf16_persian_ci": 117,
"utf16_esperanto_ci": 118,
"utf16_hungarian_ci": 119,
"utf16_sinhala_ci": 120,
"utf16_german2_ci": 121,
"utf16_croatian_ci": 122,
"utf16_unicode_520_ci": 123,
"utf16_vietnamese_ci": 124,
"ucs2_unicode_ci": 128,
"ucs2_icelandic_ci": 129,
"ucs2_latvian_ci": 130,
"ucs2_romanian_ci": 131,
"ucs2_slovenian_ci": 132,
"ucs2_polish_ci": 133,
"ucs2_estonian_ci": 134,
"ucs2_spanish_ci": 135,
"ucs2_swedish_ci": 136,
"ucs2_turkish_ci": 137,
"ucs2_czech_ci": 138,
"ucs2_danish_ci": 139,
"ucs2_lithuanian_ci": 140,
"ucs2_slovak_ci": 141,
"ucs2_spanish2_ci": 142,
"ucs2_roman_ci": 143,
"ucs2_persian_ci": 144,
"ucs2_esperanto_ci": 145,
"ucs2_hungarian_ci": 146,
"ucs2_sinhala_ci": 147,
"ucs2_german2_ci": 148,
"ucs2_croatian_ci": 149,
"ucs2_unicode_520_ci": 150,
"ucs2_vietnamese_ci": 151,
"ucs2_general_mysql500_ci": 159,
"utf32_unicode_ci": 160,
"utf32_icelandic_ci": 161,
"utf32_latvian_ci": 162,
"utf32_romanian_ci": 163,
"utf32_slovenian_ci": 164,
"utf32_polish_ci": 165,
"utf32_estonian_ci": 166,
"utf32_spanish_ci": 167,
"utf32_swedish_ci": 168,
"utf32_turkish_ci": 169,
"utf32_czech_ci": 170,
"utf32_danish_ci": 171,
"utf32_lithuanian_ci": 172,
"utf32_slovak_ci": 173,
"utf32_spanish2_ci": 174,
"utf32_roman_ci": 175,
"utf32_persian_ci": 176,
"utf32_esperanto_ci": 177,
"utf32_hungarian_ci": 178,
"utf32_sinhala_ci": 179,
"utf32_german2_ci": 180,
"utf32_croatian_ci": 181,
"utf32_unicode_520_ci": 182,
"utf32_vietnamese_ci": 183,
"utf8_unicode_ci": 192,
"utf8_icelandic_ci": 193,
"utf8_latvian_ci": 194,
"utf8_romanian_ci": 195,
"utf8_slovenian_ci": 196,
"utf8_polish_ci": 197,
"utf8_estonian_ci": 198,
"utf8_spanish_ci": 199,
"utf8_swedish_ci": 200,
"utf8_turkish_ci": 201,
"utf8_czech_ci": 202,
"utf8_danish_ci": 203,
"utf8_lithuanian_ci": 204,
"utf8_slovak_ci": 205,
"utf8_spanish2_ci": 206,
"utf8_roman_ci": 207,
"utf8_persian_ci": 208,
"utf8_esperanto_ci": 209,
"utf8_hungarian_ci": 210,
"utf8_sinhala_ci": 211,
"utf8_german2_ci": 212,
"utf8_croatian_ci": 213,
"utf8_unicode_520_ci": 214,
"utf8_vietnamese_ci": 215,
"utf8_general_mysql500_ci": 223,
"utf8mb4_unicode_ci": 224,
"utf8mb4_icelandic_ci": 225,
"utf8mb4_latvian_ci": 226,
"utf8mb4_romanian_ci": 227,
"utf8mb4_slovenian_ci": 228,
"utf8mb4_polish_ci": 229,
"utf8mb4_estonian_ci": 230,
"utf8mb4_spanish_ci": 231,
"utf8mb4_swedish_ci": 232,
"utf8mb4_turkish_ci": 233,
"utf8mb4_czech_ci": 234,
"utf8mb4_danish_ci": 235,
"utf8mb4_lithuanian_ci": 236,
"utf8mb4_slovak_ci": 237,
"utf8mb4_spanish2_ci": 238,
"utf8mb4_roman_ci": 239,
"utf8mb4_persian_ci": 240,
"utf8mb4_esperanto_ci": 241,
"utf8mb4_hungarian_ci": 242,
"utf8mb4_sinhala_ci": 243,
"utf8mb4_german2_ci": 244,
"utf8mb4_croatian_ci": 245,
"utf8mb4_unicode_520_ci": 246,
"utf8mb4_vietnamese_ci": 247,
}
// A blacklist of collations which are unsafe to use with parameter interpolation.
// These multibyte encodings may contain 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[string]bool{
"big5_chinese_ci": true,
"sjis_japanese_ci": true,
"gbk_chinese_ci": true,
"big5_bin": true,
"gb2312_bin": true,
"gbk_bin": true,
"sjis_bin": true,
"cp932_japanese_ci": true,
"cp932_bin": true,
}
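// Illustrative sketch, not part of the vendored file: one way to regenerate the
// collations map above with the query named in its comment. The DSN below is a
// placeholder assumption; adjust credentials and host as needed.
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "root:secret@tcp(127.0.0.1:3306)/information_schema")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query("SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS ORDER BY ID")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	fmt.Println("var collations = map[string]byte{")
	for rows.Next() {
		var name string
		var id int
		if err := rows.Scan(&name, &id); err != nil {
			log.Fatal(err)
		}
		if id <= 255 { // the driver keys this map by a single byte
			fmt.Printf("\t%q: %d,\n", name, id)
		}
	}
	fmt.Println("}")
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}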

View File

@ -0,0 +1,654 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"io"
"net"
"strconv"
"strings"
"time"
)
// a copy of context.Context for Go 1.7 and earlier
type mysqlContext interface {
Done() <-chan struct{}
Err() error
// defined in context.Context, but not used in this driver:
// Deadline() (deadline time.Time, ok bool)
// Value(key interface{}) interface{}
}
type mysqlConn struct {
buf buffer
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
// for context support (Go 1.8+)
watching bool
watcher chan<- mysqlContext
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.Params {
switch param {
// Charset
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
}
}
}
return
}
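// Illustrative sketch, not part of the vendored file: DSN parameters that are
// not recognized options end up in cfg.Params and are applied by handleParams
// above via SET statements. The DSN is a placeholder assumption; note that
// string values must be quoted and URL-escaped in the DSN.
package main

import (
	"database/sql"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	// charset gets a fallback list ("SET NAMES utf8mb4", then utf8 on error);
	// sql_mode is applied as "SET sql_mode='TRADITIONAL'".
	dsn := "user:pass@tcp(127.0.0.1:3306)/test?charset=utf8mb4,utf8&sql_mode=%27TRADITIONAL%27"

	db, err := sql.Open("mysql", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// The parameters are sent when a connection is actually established.
	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}
}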
func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil {
return err
}
if err != errBadConnNoWrite {
return err
}
return driver.ErrBadConn
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
return mc.begin(false)
}
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
if readOnly {
q = "START TRANSACTION READ ONLY"
} else {
q = "START TRANSACTION"
}
err := mc.exec(q)
if err == nil {
return &mysqlTx{mc}, err
}
return nil, mc.markBadConn(err)
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if !mc.closed.IsSet() {
err = mc.writeCommandPacket(comQuit)
}
mc.cleanup()
return
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successful authentication; call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
if !mc.closed.TrySet(true) {
return
}
// Makes cleanup idempotent
close(mc.closech)
if mc.netConn == nil {
return
}
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
}
}
func (mc *mysqlConn) error() error {
if mc.closed.IsSet() {
if err := mc.canceled.Value(); err != nil {
return err
}
return ErrInvalidConn
}
return nil
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
return nil, mc.markBadConn(err)
}
stmt := &mysqlStmt{
mc: mc,
}
// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// The number of '?' placeholders must equal len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
buf := mc.buf.takeCompleteBuffer()
if buf == nil {
// cannot take the buffer; something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return "", ErrInvalidConn
}
buf = buf[:0]
argPos := 0
for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q
arg := args[argPos]
argPos++
if arg == nil {
buf = append(buf, "NULL"...)
continue
}
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.Loc)
v = v.Add(time.Nanosecond * 500) // round to the nearest microsecond
year := v.Year()
year100 := year / 100
year1 := year % 100
month := v.Month()
day := v.Day()
hour := v.Hour()
minute := v.Minute()
second := v.Second()
micro := v.Nanosecond() / 1000
buf = append(buf, []byte{
'\'',
digits10[year100], digits01[year100],
digits10[year1], digits01[year1],
'-',
digits10[month], digits01[month],
'-',
digits10[day], digits01[day],
' ',
digits10[hour], digits01[hour],
':',
digits10[minute], digits01[minute],
':',
digits10[second], digits01[second],
}...)
if micro != 0 {
micro10000 := micro / 10000
micro100 := micro / 100 % 100
micro1 := micro % 100
buf = append(buf, []byte{
'.',
digits10[micro10000], digits01[micro10000],
digits10[micro100], digits01[micro100],
digits10[micro1], digits01[micro1],
}...)
}
buf = append(buf, '\'')
}
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
}
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}
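// Illustrative sketch, not part of the vendored file: client-side interpolation
// is opt-in via the DSN. With interpolateParams=true, Exec/Query calls with
// arguments go through interpolateParams above instead of a server-side
// prepared statement. The DSN and table name are placeholder assumptions.
package main

import (
	"database/sql"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/test?interpolateParams=true")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// The number of '?' placeholders must match the number of arguments,
	// otherwise the driver hands the query back to database/sql (driver.ErrSkip).
	if _, err := db.Exec("INSERT INTO t (a, b) VALUES (?, ?)", 1, "x"); err != nil {
		log.Fatal(err)
	}
}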
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, mc.markBadConn(err)
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
return err
}
// rows
if err := mc.readUntilEOF(); err != nil {
return err
}
}
return mc.discardResults()
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.query(query, args)
}
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side parameter interpolation to reduce roundtrips
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen == 0 {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
}
return nil, mc.markBadConn(err)
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}
// cancel is called when the query is canceled.
func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
}
// finish is called when the query has succeeded.
func (mc *mysqlConn) finish() {
if !mc.watching || mc.finished == nil {
return
}
select {
case mc.finished <- struct{}{}:
mc.watching = false
case <-mc.closech:
}
}
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
if err = mc.watchCancel(ctx); err != nil {
return
}
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return
}
return mc.readResultOK()
}
// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
level, err := mapIsolationLevel(opts.Isolation)
if err != nil {
return nil, err
}
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
if err != nil {
return nil, err
}
}
return mc.begin(opts.ReadOnly)
}
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := mc.query(query, dargs)
if err != nil {
mc.finish()
return nil, err
}
rows.finish = mc.finish
return rows, err
}
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
return mc.Exec(query, dargs)
}
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
stmt, err := mc.Prepare(query)
mc.finish()
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
stmt.Close()
return nil, ctx.Err()
}
return stmt, nil
}
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := stmt.query(dargs)
if err != nil {
stmt.mc.finish()
return nil, err
}
rows.finish = stmt.mc.finish
return rows, err
}
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
defer stmt.mc.finish()
return stmt.Exec(dargs)
}
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
if ctx.Done() == nil {
return nil
}
mc.watching = true
select {
default:
case <-ctx.Done():
return ctx.Err()
}
if mc.watcher == nil {
return nil
}
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan mysqlContext, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx mysqlContext
select {
case ctx = <-watcher:
case <-mc.closech:
return
}
select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() {
return driver.ErrBadConn
}
return nil
}
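// Illustrative sketch, not part of the vendored file: the watcher machinery
// above implements context cancellation; from database/sql it is exercised via
// the *Context methods. The DSN is a placeholder assumption.
package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/test")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// If the deadline expires mid-query, the driver cancels the connection
	// (mc.cancel above), so it is discarded instead of being reused.
	var n int
	if err := db.QueryRowContext(ctx, "SELECT SLEEP(5)").Scan(&n); err != nil {
		log.Println("query ended with:", err)
	}
}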

View File

@ -0,0 +1,174 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const (
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"
)
// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
const (
iOK byte = 0x00
iAuthMoreData byte = 0x01
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
)
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
type clientFlag uint32
const (
clientLongPassword clientFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
clientNoSchema
clientCompress
clientODBC
clientLocalFiles
clientIgnoreSpace
clientProtocol41
clientInteractive
clientSSL
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientMultiStatements
clientMultiResults
clientPSMultiResults
clientPluginAuth
clientConnectAttrs
clientPluginAuthLenEncClientData
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
)
const (
comQuit byte = iota + 1
comInitDB
comQuery
comFieldList
comCreateDB
comDropDB
comRefresh
comShutdown
comStatistics
comProcessInfo
comConnect
comProcessKill
comDebug
comPing
comTime
comDelayedInsert
comChangeUser
comBinlogDump
comTableDump
comConnectOut
comRegisterSlave
comStmtPrepare
comStmtExecute
comStmtSendLongData
comStmtClose
comStmtReset
comSetOption
comStmtFetch
)
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
type fieldType byte
const (
fieldTypeDecimal fieldType = iota
fieldTypeTiny
fieldTypeShort
fieldTypeLong
fieldTypeFloat
fieldTypeDouble
fieldTypeNULL
fieldTypeTimestamp
fieldTypeLongLong
fieldTypeInt24
fieldTypeDate
fieldTypeTime
fieldTypeDateTime
fieldTypeYear
fieldTypeNewDate
fieldTypeVarChar
fieldTypeBit
)
const (
fieldTypeJSON fieldType = iota + 0xf5
fieldTypeNewDecimal
fieldTypeEnum
fieldTypeSet
fieldTypeTinyBLOB
fieldTypeMediumBLOB
fieldTypeLongBLOB
fieldTypeBLOB
fieldTypeVarString
fieldTypeString
fieldTypeGeometry
)
type fieldFlag uint16
const (
flagNotNULL fieldFlag = 1 << iota
flagPriKey
flagUniqueKey
flagMultipleKey
flagBLOB
flagUnsigned
flagZeroFill
flagBinary
flagEnum
flagAutoIncrement
flagTimestamp
flagSet
flagUnknown1
flagUnknown2
flagUnknown3
flagUnknown4
)
// http://dev.mysql.com/doc/internals/en/status-flags.html
type statusFlag uint16
const (
statusInTrans statusFlag = 1 << iota
statusInAutocommit
statusReserved // Not in documentation
statusMoreResultsExists
statusNoGoodIndexUsed
statusNoIndexUsed
statusCursorExists
statusLastRowSent
statusDbDropped
statusNoBackslashEscapes
statusMetadataChanged
statusQueryWasSlow
statusPsOutParams
statusInTransReadonly
statusSessionStateChanged
)
const (
cachingSha2PasswordRequestPublicKey = 2
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)
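// Illustrative sketch, not part of the vendored file: the flag types above are
// plain bitmasks built with "1 << iota"; a capability or status bit is tested
// with a bitwise AND, as done throughout the driver. The constants below are
// local copies for illustration only.
package main

import "fmt"

type clientFlag uint32

const (
	clientLongPassword clientFlag = 1 << iota
	clientFoundRows
	clientLongFlag
	clientConnectWithDB
)

func main() {
	flags := clientLongPassword | clientConnectWithDB
	fmt.Println(flags&clientFoundRows != 0)     // false: bit not set
	fmt.Println(flags&clientConnectWithDB != 0) // true: bit set
}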

View File

@ -0,0 +1,165 @@
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Package mysql provides a MySQL driver for Go's database/sql package.
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql
import (
"database/sql"
"database/sql/driver"
"net"
"sync"
)
// MySQLDriver is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
// DialFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDial
type DialFunc func(addr string) (net.Conn, error)
var (
dialsLock sync.RWMutex
dials map[string]DialFunc
)
// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
func RegisterDial(net string, dial DialFunc) {
dialsLock.Lock()
defer dialsLock.Unlock()
if dials == nil {
dials = make(map[string]DialFunc)
}
dials[net] = dial
}
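// Illustrative sketch, not part of the vendored file: registering a custom dial
// function and selecting it through the DSN network name. The name "mydial"
// and the DSN are placeholder assumptions.
package main

import (
	"database/sql"
	"log"
	"net"
	"time"

	"github.com/go-sql-driver/mysql"
)

func main() {
	mysql.RegisterDial("mydial", func(addr string) (net.Conn, error) {
		// e.g. dial through a proxy or add a custom timeout here
		return net.DialTimeout("tcp", addr, 3*time.Second)
	})

	db, err := sql.Open("mysql", "user:pass@mydial(127.0.0.1:3306)/test")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}
}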
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formatted
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
}
mc.cfg, err = ParseDSN(dsn)
if err != nil {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}
// Call startWatcher for context support (From Go 1.8)
mc.startWatcher()
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
}
if plugin == "" {
plugin = defaultAuthPlugin
}
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
if err != nil {
// fall back to the default auth plugin if the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, addNUL, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
mc.cleanup()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
func init() {
sql.Register("mysql", &MySQLDriver{})
}

View File

@ -0,0 +1,611 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"crypto/rsa"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
var (
errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowNativePasswords bool // Allows the native password authentication method
AllowOldPasswords bool // Allows the old insecure password method
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
}
// NewConfig creates a new Config and sets default values.
func NewConfig() *Config {
return &Config{
Collation: defaultCollation,
Loc: time.UTC,
MaxAllowedPacket: defaultMaxAllowedPacket,
AllowNativePasswords: true,
}
}
func (cfg *Config) normalize() error {
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
return errInvalidDSNUnsafeCollation
}
// Set default network if empty
if cfg.Net == "" {
cfg.Net = "tcp"
}
// Set default address if empty
if cfg.Addr == "" {
switch cfg.Net {
case "tcp":
cfg.Addr = "127.0.0.1:3306"
case "unix":
cfg.Addr = "/tmp/mysql.sock"
default:
return errors.New("default addr for network '" + cfg.Net + "' unknown")
}
} else if cfg.Net == "tcp" {
cfg.Addr = ensureHavePort(cfg.Addr)
}
if cfg.tls != nil {
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
}
}
}
return nil
}
// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
func (cfg *Config) FormatDSN() string {
var buf bytes.Buffer
// [username[:password]@]
if len(cfg.User) > 0 {
buf.WriteString(cfg.User)
if len(cfg.Passwd) > 0 {
buf.WriteByte(':')
buf.WriteString(cfg.Passwd)
}
buf.WriteByte('@')
}
// [protocol[(address)]]
if len(cfg.Net) > 0 {
buf.WriteString(cfg.Net)
if len(cfg.Addr) > 0 {
buf.WriteByte('(')
buf.WriteString(cfg.Addr)
buf.WriteByte(')')
}
}
// /dbname
buf.WriteByte('/')
buf.WriteString(cfg.DBName)
// [?param1=value1&...&paramN=valueN]
hasParam := false
if cfg.AllowAllFiles {
hasParam = true
buf.WriteString("?allowAllFiles=true")
}
if cfg.AllowCleartextPasswords {
if hasParam {
buf.WriteString("&allowCleartextPasswords=true")
} else {
hasParam = true
buf.WriteString("?allowCleartextPasswords=true")
}
}
if !cfg.AllowNativePasswords {
if hasParam {
buf.WriteString("&allowNativePasswords=false")
} else {
hasParam = true
buf.WriteString("?allowNativePasswords=false")
}
}
if cfg.AllowOldPasswords {
if hasParam {
buf.WriteString("&allowOldPasswords=true")
} else {
hasParam = true
buf.WriteString("?allowOldPasswords=true")
}
}
if cfg.ClientFoundRows {
if hasParam {
buf.WriteString("&clientFoundRows=true")
} else {
hasParam = true
buf.WriteString("?clientFoundRows=true")
}
}
if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
if hasParam {
buf.WriteString("&collation=")
} else {
hasParam = true
buf.WriteString("?collation=")
}
buf.WriteString(col)
}
if cfg.ColumnsWithAlias {
if hasParam {
buf.WriteString("&columnsWithAlias=true")
} else {
hasParam = true
buf.WriteString("?columnsWithAlias=true")
}
}
if cfg.InterpolateParams {
if hasParam {
buf.WriteString("&interpolateParams=true")
} else {
hasParam = true
buf.WriteString("?interpolateParams=true")
}
}
if cfg.Loc != time.UTC && cfg.Loc != nil {
if hasParam {
buf.WriteString("&loc=")
} else {
hasParam = true
buf.WriteString("?loc=")
}
buf.WriteString(url.QueryEscape(cfg.Loc.String()))
}
if cfg.MultiStatements {
if hasParam {
buf.WriteString("&multiStatements=true")
} else {
hasParam = true
buf.WriteString("?multiStatements=true")
}
}
if cfg.ParseTime {
if hasParam {
buf.WriteString("&parseTime=true")
} else {
hasParam = true
buf.WriteString("?parseTime=true")
}
}
if cfg.ReadTimeout > 0 {
if hasParam {
buf.WriteString("&readTimeout=")
} else {
hasParam = true
buf.WriteString("?readTimeout=")
}
buf.WriteString(cfg.ReadTimeout.String())
}
if cfg.RejectReadOnly {
if hasParam {
buf.WriteString("&rejectReadOnly=true")
} else {
hasParam = true
buf.WriteString("?rejectReadOnly=true")
}
}
if len(cfg.ServerPubKey) > 0 {
if hasParam {
buf.WriteString("&serverPubKey=")
} else {
hasParam = true
buf.WriteString("?serverPubKey=")
}
buf.WriteString(url.QueryEscape(cfg.ServerPubKey))
}
if cfg.Timeout > 0 {
if hasParam {
buf.WriteString("&timeout=")
} else {
hasParam = true
buf.WriteString("?timeout=")
}
buf.WriteString(cfg.Timeout.String())
}
if len(cfg.TLSConfig) > 0 {
if hasParam {
buf.WriteString("&tls=")
} else {
hasParam = true
buf.WriteString("?tls=")
}
buf.WriteString(url.QueryEscape(cfg.TLSConfig))
}
if cfg.WriteTimeout > 0 {
if hasParam {
buf.WriteString("&writeTimeout=")
} else {
hasParam = true
buf.WriteString("?writeTimeout=")
}
buf.WriteString(cfg.WriteTimeout.String())
}
if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
if hasParam {
buf.WriteString("&maxAllowedPacket=")
} else {
hasParam = true
buf.WriteString("?maxAllowedPacket=")
}
buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
}
// other params
if cfg.Params != nil {
var params []string
for param := range cfg.Params {
params = append(params, param)
}
sort.Strings(params)
for _, param := range params {
if hasParam {
buf.WriteByte('&')
} else {
hasParam = true
buf.WriteByte('?')
}
buf.WriteString(param)
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(cfg.Params[param]))
}
}
return buf.String()
}
// ParseDSN parses the DSN string to a Config
func ParseDSN(dsn string) (cfg *Config, err error) {
// New config with some default values
cfg = NewConfig()
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.Passwd = dsn[k+1 : j]
break
}
}
cfg.User = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.Addr = dsn[k+1 : i-1]
break
}
}
cfg.Net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.DBName = dsn[i+1 : j]
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
if err = cfg.normalize(); err != nil {
return nil, err
}
return
}
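// Illustrative sketch, not part of the vendored file: parsing a DSN into a
// Config and rendering it back. The DSN is a placeholder assumption.
package main

import (
	"fmt"
	"log"

	"github.com/go-sql-driver/mysql"
)

func main() {
	cfg, err := mysql.ParseDSN("user:pass@tcp(127.0.0.1:3306)/test?parseTime=true&readTimeout=30s")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.User, cfg.Net, cfg.Addr, cfg.DBName, cfg.ParseTime, cfg.ReadTimeout)

	// FormatDSN renders the normalized configuration back into a DSN string.
	fmt.Println(cfg.FormatDSN())
}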
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *Config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
// cfg params
switch value := param[1]; param[0] {
// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.AllowAllFiles, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use cleartext authentication mode (MySQL 5.5.10+)
case "allowCleartextPasswords":
var isBool bool
cfg.AllowCleartextPasswords, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use native password authentication
case "allowNativePasswords":
var isBool bool
cfg.AllowNativePasswords, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.AllowOldPasswords, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.ClientFoundRows, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Collation
case "collation":
cfg.Collation = value
break
case "columnsWithAlias":
var isBool bool
cfg.ColumnsWithAlias, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Compression
case "compress":
return errors.New("compression not implemented yet")
// Enable client side placeholder substitution
case "interpolateParams":
var isBool bool
cfg.InterpolateParams, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Time Location
case "loc":
if value, err = url.QueryUnescape(value); err != nil {
return
}
cfg.Loc, err = time.LoadLocation(value)
if err != nil {
return
}
// multiple statements in one query
case "multiStatements":
var isBool bool
cfg.MultiStatements, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// time.Time parsing
case "parseTime":
var isBool bool
cfg.ParseTime, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// I/O read Timeout
case "readTimeout":
cfg.ReadTimeout, err = time.ParseDuration(value)
if err != nil {
return
}
// Reject read-only connections
case "rejectReadOnly":
var isBool bool
cfg.RejectReadOnly, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Server public key
case "serverPubKey":
name, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid value for server pub key name: %v", err)
}
if pubKey := getServerPubKey(name); pubKey != nil {
cfg.ServerPubKey = name
cfg.pubKey = pubKey
} else {
return errors.New("invalid value / unknown server pub key name: " + name)
}
// Strict mode
case "strict":
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
// Dial Timeout
case "timeout":
cfg.Timeout, err = time.ParseDuration(value)
if err != nil {
return
}
// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.TLSConfig = "true"
cfg.tls = &tls.Config{}
} else {
cfg.TLSConfig = "false"
}
} else if vl := strings.ToLower(value); vl == "skip-verify" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else {
name, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid value for TLS config name: %v", err)
}
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
cfg.TLSConfig = name
cfg.tls = tlsConfig
} else {
return errors.New("invalid value / unknown config name: " + name)
}
}
// I/O write Timeout
case "writeTimeout":
cfg.WriteTimeout, err = time.ParseDuration(value)
if err != nil {
return
}
case "maxAllowedPacket":
cfg.MaxAllowedPacket, err = strconv.Atoi(value)
if err != nil {
return
}
default:
// lazy init
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}
return
}
func ensureHavePort(addr string) string {
if _, _, err := net.SplitHostPort(addr); err != nil {
return net.JoinHostPort(addr, "3306")
}
return addr
}

View File

@ -0,0 +1,65 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"errors"
"fmt"
"log"
"os"
)
// Various errors the driver might return. Can change between driver versions.
var (
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication.")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend.
// See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection")
)
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
// Logger is used to log critical error messages.
type Logger interface {
Print(v ...interface{})
}
// SetLogger is used to set the logger for critical errors.
// The initial logger writes to os.Stderr.
func SetLogger(logger Logger) error {
if logger == nil {
return errors.New("logger is nil")
}
errLog = logger
return nil
}
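// Illustrative sketch, not part of the vendored file: routing the driver's
// critical error messages to a custom destination. *log.Logger satisfies the
// Logger interface; the prefix is a placeholder assumption.
package main

import (
	"log"
	"os"

	"github.com/go-sql-driver/mysql"
)

func main() {
	logger := log.New(os.Stdout, "[mysql-driver] ", log.LstdFlags|log.Lshortfile)
	if err := mysql.SetLogger(logger); err != nil {
		log.Fatal(err)
	}
}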
// MySQLError is an error type which represents a single MySQL error
type MySQLError struct {
Number uint16
Message string
}
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}

View File

@ -0,0 +1,194 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql"
"reflect"
)
func (mf *mysqlField) typeDatabaseName() string {
switch mf.fieldType {
case fieldTypeBit:
return "BIT"
case fieldTypeBLOB:
if mf.charSet != collations[binaryCollation] {
return "TEXT"
}
return "BLOB"
case fieldTypeDate:
return "DATE"
case fieldTypeDateTime:
return "DATETIME"
case fieldTypeDecimal:
return "DECIMAL"
case fieldTypeDouble:
return "DOUBLE"
case fieldTypeEnum:
return "ENUM"
case fieldTypeFloat:
return "FLOAT"
case fieldTypeGeometry:
return "GEOMETRY"
case fieldTypeInt24:
return "MEDIUMINT"
case fieldTypeJSON:
return "JSON"
case fieldTypeLong:
return "INT"
case fieldTypeLongBLOB:
if mf.charSet != collations[binaryCollation] {
return "LONGTEXT"
}
return "LONGBLOB"
case fieldTypeLongLong:
return "BIGINT"
case fieldTypeMediumBLOB:
if mf.charSet != collations[binaryCollation] {
return "MEDIUMTEXT"
}
return "MEDIUMBLOB"
case fieldTypeNewDate:
return "DATE"
case fieldTypeNewDecimal:
return "DECIMAL"
case fieldTypeNULL:
return "NULL"
case fieldTypeSet:
return "SET"
case fieldTypeShort:
return "SMALLINT"
case fieldTypeString:
if mf.charSet == collations[binaryCollation] {
return "BINARY"
}
return "CHAR"
case fieldTypeTime:
return "TIME"
case fieldTypeTimestamp:
return "TIMESTAMP"
case fieldTypeTiny:
return "TINYINT"
case fieldTypeTinyBLOB:
if mf.charSet != collations[binaryCollation] {
return "TINYTEXT"
}
return "TINYBLOB"
case fieldTypeVarChar:
if mf.charSet == collations[binaryCollation] {
return "VARBINARY"
}
return "VARCHAR"
case fieldTypeVarString:
if mf.charSet == collations[binaryCollation] {
return "VARBINARY"
}
return "VARCHAR"
case fieldTypeYear:
return "YEAR"
default:
return ""
}
}
var (
scanTypeFloat32 = reflect.TypeOf(float32(0))
scanTypeFloat64 = reflect.TypeOf(float64(0))
scanTypeInt8 = reflect.TypeOf(int8(0))
scanTypeInt16 = reflect.TypeOf(int16(0))
scanTypeInt32 = reflect.TypeOf(int32(0))
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(NullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
scanTypeUnknown = reflect.TypeOf(new(interface{}))
)
type mysqlField struct {
tableName string
name string
length uint32
flags fieldFlag
fieldType fieldType
decimals byte
charSet uint8
}
func (mf *mysqlField) scanType() reflect.Type {
switch mf.fieldType {
case fieldTypeTiny:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint8
}
return scanTypeInt8
}
return scanTypeNullInt
case fieldTypeShort, fieldTypeYear:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint16
}
return scanTypeInt16
}
return scanTypeNullInt
case fieldTypeInt24, fieldTypeLong:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint32
}
return scanTypeInt32
}
return scanTypeNullInt
case fieldTypeLongLong:
if mf.flags&flagNotNULL != 0 {
if mf.flags&flagUnsigned != 0 {
return scanTypeUint64
}
return scanTypeInt64
}
return scanTypeNullInt
case fieldTypeFloat:
if mf.flags&flagNotNULL != 0 {
return scanTypeFloat32
}
return scanTypeNullFloat
case fieldTypeDouble:
if mf.flags&flagNotNULL != 0 {
return scanTypeFloat64
}
return scanTypeNullFloat
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
fieldTypeTime:
return scanTypeRawBytes
case fieldTypeDate, fieldTypeNewDate,
fieldTypeTimestamp, fieldTypeDateTime:
// NullTime is always returned for more consistent behavior as it can
// handle both cases of parseTime, regardless of whether the field is nullable.
return scanTypeNullTime
default:
return scanTypeUnknown
}
}
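// Illustrative sketch, not part of the vendored file: the metadata above is
// exposed through database/sql column types. DSN and table are placeholder
// assumptions.
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/test")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query("SELECT * FROM some_table LIMIT 1")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	cols, err := rows.ColumnTypes()
	if err != nil {
		log.Fatal(err)
	}
	for _, c := range cols {
		nullable, known := c.Nullable()
		fmt.Printf("%s: db type %s, scan type %v, nullable %v (known %v)\n",
			c.Name(), c.DatabaseTypeName(), c.ScanType(), nullable, known)
	}
}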

View File

@ -0,0 +1,182 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"fmt"
"io"
"os"
"strings"
"sync"
)
var (
fileRegister map[string]bool
fileRegisterLock sync.RWMutex
readerRegister map[string]func() io.Reader
readerRegisterLock sync.RWMutex
)
// RegisterLocalFile adds the given file to the file whitelist,
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
// Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true'
//
// filePath := "/home/gopher/data.csv"
// mysql.RegisterLocalFile(filePath)
// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterLocalFile(filePath string) {
fileRegisterLock.Lock()
// lazy map init
if fileRegister == nil {
fileRegister = make(map[string]bool)
}
fileRegister[strings.Trim(filePath, `"`)] = true
fileRegisterLock.Unlock()
}
// DeregisterLocalFile removes the given filepath from the whitelist.
func DeregisterLocalFile(filePath string) {
fileRegisterLock.Lock()
delete(fileRegister, strings.Trim(filePath, `"`))
fileRegisterLock.Unlock()
}
// RegisterReaderHandler registers a handler function which is used
// to receive an io.Reader.
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
// If the handler returns an io.ReadCloser, Close() is called when the
// request is finished.
//
// mysql.RegisterReaderHandler("data", func() io.Reader {
// var csvReader io.Reader // Some Reader that returns CSV data
// ... // Open Reader here
// return csvReader
// })
// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterReaderHandler(name string, handler func() io.Reader) {
readerRegisterLock.Lock()
// lazy map init
if readerRegister == nil {
readerRegister = make(map[string]func() io.Reader)
}
readerRegister[name] = handler
readerRegisterLock.Unlock()
}
// DeregisterReaderHandler removes the ReaderHandler function with
// the given name from the registry.
func DeregisterReaderHandler(name string) {
readerRegisterLock.Lock()
delete(readerRegister, name)
readerRegisterLock.Unlock()
}
func deferredClose(err *error, closer io.Closer) {
closeErr := closer.Close()
if *err == nil {
*err = closeErr
}
}
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize
}
if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
// The server might return an absolute path. See issue #355.
name = name[idx+8:]
readerRegisterLock.RLock()
handler, inMap := readerRegister[name]
readerRegisterLock.RUnlock()
if inMap {
rdr = handler()
if rdr != nil {
if cl, ok := rdr.(io.Closer); ok {
defer deferredClose(&err, cl)
}
} else {
err = fmt.Errorf("Reader '%s' is <nil>", name)
}
} else {
err = fmt.Errorf("Reader '%s' is not registered", name)
}
} else { // File
name = strings.Trim(name, `"`)
fileRegisterLock.RLock()
fr := fileRegister[name]
fileRegisterLock.RUnlock()
if mc.cfg.AllowAllFiles || fr {
var file *os.File
var fi os.FileInfo
if file, err = os.Open(name); err == nil {
defer deferredClose(&err, file)
// get file size
if fi, err = file.Stat(); err == nil {
rdr = file
if fileSize := int(fi.Size()); fileSize < packetSize {
packetSize = fileSize
}
}
}
} else {
err = fmt.Errorf("local file '%s' is not registered", name)
}
}
// send content packets
// if packetSize == 0, the Reader contains no data
if err == nil && packetSize > 0 {
data := make([]byte, 4+packetSize)
var n int
for err == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
return ioErr
}
}
}
if err == io.EOF {
err = nil
}
}
// send empty packet (termination)
if data == nil {
data = make([]byte, 4)
}
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
return ioErr
}
// read OK packet
if err == nil {
return mc.readResultOK()
}
mc.readPacket()
return err
}
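// Illustrative sketch, not part of the vendored file: streaming data into a
// table through a registered Reader handler, mirroring the doc comments above.
// DSN and table name are placeholder assumptions, and the server must permit
// LOAD DATA LOCAL INFILE.
package main

import (
	"database/sql"
	"io"
	"log"
	"strings"

	"github.com/go-sql-driver/mysql"
)

func main() {
	mysql.RegisterReaderHandler("mydata", func() io.Reader {
		return strings.NewReader("1,foo\n2,bar\n")
	})
	defer mysql.DeregisterReaderHandler("mydata")

	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/test")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	_, err = db.Exec("LOAD DATA LOCAL INFILE 'Reader::mydata' INTO TABLE t FIELDS TERMINATED BY ','")
	if err != nil {
		log.Fatal(err)
	}
}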

File diff suppressed because it is too large

View File

@ -0,0 +1,22 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlResult struct {
affectedRows int64
insertId int64
}
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil
}

View File

@ -0,0 +1,216 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"io"
"math"
"reflect"
)
type resultSet struct {
columns []mysqlField
columnNames []string
done bool
}
type mysqlRows struct {
mc *mysqlConn
rs resultSet
finish func()
}
type binaryRows struct {
mysqlRows
}
type textRows struct {
mysqlRows
}
func (rows *mysqlRows) Columns() []string {
if rows.rs.columnNames != nil {
return rows.rs.columnNames
}
columns := make([]string, len(rows.rs.columns))
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
for i := range columns {
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
columns[i] = tableName + "." + rows.rs.columns[i].name
} else {
columns[i] = rows.rs.columns[i].name
}
}
} else {
for i := range columns {
columns[i] = rows.rs.columns[i].name
}
}
rows.rs.columnNames = columns
return columns
}
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
return rows.rs.columns[i].typeDatabaseName()
}
// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
// return int64(rows.rs.columns[i].length), true
// }
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
return rows.rs.columns[i].flags&flagNotNULL == 0, true
}
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
column := rows.rs.columns[i]
decimals := int64(column.decimals)
switch column.fieldType {
case fieldTypeDecimal, fieldTypeNewDecimal:
if decimals > 0 {
return int64(column.length) - 2, decimals, true
}
return int64(column.length) - 1, decimals, true
case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
return decimals, decimals, true
case fieldTypeFloat, fieldTypeDouble:
if decimals == 0x1f {
return math.MaxInt64, math.MaxInt64, true
}
return math.MaxInt64, decimals, true
}
return 0, 0, false
}
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
return rows.rs.columns[i].scanType()
}
func (rows *mysqlRows) Close() (err error) {
if f := rows.finish; f != nil {
f()
rows.finish = nil
}
mc := rows.mc
if mc == nil {
return nil
}
if err := mc.error(); err != nil {
return err
}
// Remove unread packets from stream
if !rows.rs.done {
err = mc.readUntilEOF()
}
if err == nil {
if err = mc.discardResults(); err != nil {
return err
}
}
rows.mc = nil
return err
}
func (rows *mysqlRows) HasNextResultSet() (b bool) {
if rows.mc == nil {
return false
}
return rows.mc.status&statusMoreResultsExists != 0
}
func (rows *mysqlRows) nextResultSet() (int, error) {
if rows.mc == nil {
return 0, io.EOF
}
if err := rows.mc.error(); err != nil {
return 0, err
}
// Remove unread packets from stream
if !rows.rs.done {
if err := rows.mc.readUntilEOF(); err != nil {
return 0, err
}
rows.rs.done = true
}
if !rows.HasNextResultSet() {
rows.mc = nil
return 0, io.EOF
}
rows.rs = resultSet{}
return rows.mc.readResultSetHeaderPacket()
}
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
for {
resLen, err := rows.nextResultSet()
if err != nil {
return 0, err
}
if resLen > 0 {
return resLen, nil
}
rows.rs.done = true
}
}
func (rows *binaryRows) NextResultSet() error {
resLen, err := rows.nextNotEmptyResultSet()
if err != nil {
return err
}
rows.rs.columns, err = rows.mc.readColumns(resLen)
return err
}
func (rows *binaryRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if err := mc.error(); err != nil {
return err
}
// Fetch next row from stream
return rows.readRow(dest)
}
return io.EOF
}
func (rows *textRows) NextResultSet() (err error) {
resLen, err := rows.nextNotEmptyResultSet()
if err != nil {
return err
}
rows.rs.columns, err = rows.mc.readColumns(resLen)
return err
}
func (rows *textRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if err := mc.error(); err != nil {
return err
}
// Fetch next row from stream
return rows.readRow(dest)
}
return io.EOF
}
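// Illustrative sketch, not part of the vendored file: iterating over multiple
// result sets, which the code above supports when multiStatements is enabled.
// DSN and queries are placeholder assumptions.
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/test?multiStatements=true")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query("SELECT 1; SELECT 2, 3")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for set := 0; ; set++ {
		cols, err := rows.Columns()
		if err != nil {
			log.Fatal(err)
		}
		n := 0
		for rows.Next() {
			n++
		}
		fmt.Printf("result set %d: columns %v, %d row(s)\n", set, cols, n)
		if !rows.NextResultSet() {
			break
		}
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}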

View File

@ -0,0 +1,211 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"fmt"
"io"
"reflect"
"strconv"
)
type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
}
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.closed.IsSet() {
// driver.Stmt.Close can be called more than once, thus this function
// has to be idempotent.
// See also Issue #450 and golang/go#16019.
//errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
stmt.mc = nil
return err
}
func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
return converter{}
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
}
mc := stmt.mc
mc.affectedRows = 0
mc.insertId = 0
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
if resLen > 0 {
// Columns
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
// Rows
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
if err := mc.discardResults(); err != nil {
return nil, err
}
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, nil
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return stmt.query(args)
}
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
}
mc := stmt.mc
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
if resLen > 0 {
rows.mc = mc
rows.rs.columns, err = mc.readColumns(resLen)
} else {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
return rows, err
}
type converter struct{}
// ConvertValue mirrors the reference/default converter in database/sql/driver
// with _one_ exception: we support uint64 values with their high bit set, which
// the default implementation does not. This function should be kept in sync with
// database/sql/driver defaultConverter.ConvertValue() except for that
// deliberate difference.
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
if vr, ok := v.(driver.Valuer); ok {
sv, err := callValuerValue(vr)
if err != nil {
return nil, err
}
if !driver.IsValue(sv) {
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
}
return sv, nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
} else {
return c.ConvertValue(rv.Elem().Interface())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return strconv.FormatUint(u64, 10), nil
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
case reflect.Bool:
return rv.Bool(), nil
case reflect.Slice:
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return rv.String(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This is an exact copy of the same-named unexported function from the
// database/sql package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}

View File

@ -0,0 +1,31 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlTx struct {
mc *mysqlConn
}
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.closed.IsSet() {
return ErrInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.closed.IsSet() {
return ErrInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}

View File

@ -0,0 +1,755 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Registry for custom tls.Configs
var (
tlsConfigLock sync.RWMutex
tlsConfigRegistry map[string]*tls.Config
)
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
// Note: The provided tls.Config is exclusively owned by the driver after
// registering it.
//
// rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
// if err != nil {
// log.Fatal(err)
// }
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
// log.Fatal("Failed to append PEM.")
// }
// clientCert := make([]tls.Certificate, 0, 1)
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
// if err != nil {
// log.Fatal(err)
// }
// clientCert = append(clientCert, certs)
// mysql.RegisterTLSConfig("custom", &tls.Config{
// RootCAs: rootCertPool,
// Certificates: clientCert,
// })
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
return fmt.Errorf("key '%s' is reserved", key)
}
tlsConfigLock.Lock()
if tlsConfigRegistry == nil {
tlsConfigRegistry = make(map[string]*tls.Config)
}
tlsConfigRegistry[key] = config
tlsConfigLock.Unlock()
return nil
}
// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
tlsConfigLock.Lock()
if tlsConfigRegistry != nil {
delete(tlsConfigRegistry, key)
}
tlsConfigLock.Unlock()
}
func getTLSConfigClone(key string) (config *tls.Config) {
tlsConfigLock.RLock()
if v, ok := tlsConfigRegistry[key]; ok {
config = v.Clone()
}
tlsConfigLock.RUnlock()
return
}
// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}
/******************************************************************************
* Time related utils *
******************************************************************************/
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
}
switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
}
nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
base := "0000-00-00 00:00:00.0000000"
switch len(str) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
}
t, err = time.Parse(timeFormat[:len(str)], str)
default:
err = fmt.Errorf("invalid time string: %s", str)
return
}
// Adjust location
if err == nil && loc != time.UTC {
y, mo, d := t.Date()
h, mi, s := t.Clock()
t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
}
return
}
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
switch num {
case 0:
return time.Time{}, nil
case 4:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
0, 0, 0, 0,
loc,
), nil
case 7:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
0,
loc,
), nil
case 11:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
loc,
), nil
}
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
}
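// The helper below is an illustrative, standalone sketch (not part of the
// vendored source) of the 7-byte binary DATETIME layout handled above: a
// little-endian uint16 year followed by one byte each for month, day, hour,
// minute and second. The 11-byte form appends a little-endian uint32
// microsecond count.
func decodeBinaryDateTime7(data []byte, loc *time.Location) time.Time {
	return time.Date(
		int(binary.LittleEndian.Uint16(data[:2])), // year
		time.Month(data[2]),                       // month
		int(data[3]),                              // day
		int(data[4]), int(data[5]), int(data[6]),  // hour, minute, second
		0, // no fractional seconds in the 7-byte form
		loc,
	)
}

// For example, decodeBinaryDateTime7([]byte{0xe3, 0x07, 1, 2, 3, 4, 5}, time.UTC)
// yields 2019-01-02 03:04:05 UTC.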
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
// if the DATE or DATETIME has the zero value.
// It must never be changed.
// The current behavior depends on database/sql copying the result.
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
func appendMicrosecs(dst, src []byte, decimals int) []byte {
if decimals <= 0 {
return dst
}
if len(src) == 0 {
return append(dst, ".000000"[:decimals+1]...)
}
microsecs := binary.LittleEndian.Uint32(src[:4])
p1 := byte(microsecs / 10000)
microsecs -= 10000 * uint32(p1)
p2 := byte(microsecs / 100)
microsecs -= 100 * uint32(p2)
p3 := byte(microsecs)
switch decimals {
default:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3], digits01[p3],
)
case 1:
return append(dst, '.',
digits10[p1],
)
case 2:
return append(dst, '.',
digits10[p1], digits01[p1],
)
case 3:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2],
)
case 4:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
)
case 5:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3],
)
}
}
func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
// length expects the deterministic length of the zero value;
// negative time and 100+ hours are automatically added if needed
if len(src) == 0 {
return zeroDateTime[:length], nil
}
var dst []byte // return value
var p1, p2, p3 byte // current digit pair
switch length {
case 10, 19, 21, 22, 23, 24, 25, 26:
default:
t := "DATE"
if length > 10 {
t += "TIME"
}
return nil, fmt.Errorf("illegal %s length %d", t, length)
}
switch len(src) {
case 4, 7, 11:
default:
t := "DATE"
if length > 10 {
t += "TIME"
}
return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
}
dst = make([]byte, 0, length)
// start with the date
year := binary.LittleEndian.Uint16(src[:2])
pt := year / 100
p1 = byte(year - 100*uint16(pt))
p2, p3 = src[2], src[3]
dst = append(dst,
digits10[pt], digits01[pt],
digits10[p1], digits01[p1], '-',
digits10[p2], digits01[p2], '-',
digits10[p3], digits01[p3],
)
if length == 10 {
return dst, nil
}
if len(src) == 4 {
return append(dst, zeroDateTime[10:length]...), nil
}
dst = append(dst, ' ')
p1 = src[4] // hour
src = src[5:]
// p1 is 2-digit hour, src is after hour
p2, p3 = src[0], src[1]
dst = append(dst,
digits10[p1], digits01[p1], ':',
digits10[p2], digits01[p2], ':',
digits10[p3], digits01[p3],
)
return appendMicrosecs(dst, src[2:], int(length)-20), nil
}
func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
// length expects the deterministic length of the zero value;
// negative time and 100+ hours are automatically added if needed
if len(src) == 0 {
return zeroDateTime[11 : 11+length], nil
}
var dst []byte // return value
switch length {
case
8, // time (can be up to 10 when negative and 100+ hours)
10, 11, 12, 13, 14, 15: // time with fractional seconds
default:
return nil, fmt.Errorf("illegal TIME length %d", length)
}
switch len(src) {
case 8, 12:
default:
return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
}
// +2 to enable negative time and 100+ hours
dst = make([]byte, 0, length+2)
if src[0] == 1 {
dst = append(dst, '-')
}
days := binary.LittleEndian.Uint32(src[1:5])
hours := int64(days)*24 + int64(src[5])
if hours >= 100 {
dst = strconv.AppendInt(dst, hours, 10)
} else {
dst = append(dst, digits10[hours], digits01[hours])
}
min, sec := src[6], src[7]
dst = append(dst, ':',
digits10[min], digits01[min], ':',
digits10[sec], digits01[sec],
)
return appendMicrosecs(dst, src[8:], int(length)-9), nil
}
/******************************************************************************
* Convert from and to bytes *
******************************************************************************/
func uint64ToBytes(n uint64) []byte {
return []byte{
byte(n),
byte(n >> 8),
byte(n >> 16),
byte(n >> 24),
byte(n >> 32),
byte(n >> 40),
byte(n >> 48),
byte(n >> 56),
}
}
func uint64ToString(n uint64) []byte {
var a [20]byte
i := 20
// U+0030 = 0
// ...
// U+0039 = 9
var q uint64
for n >= 10 {
i--
q = n / 10
a[i] = uint8(n-q*10) + 0x30
n = q
}
i--
a[i] = uint8(n) + 0x30
return a[i:]
}
// treats string value as unsigned integer representation
func stringToInt(b []byte) int {
val := 0
for i := range b {
val *= 10
val += int(b[i] - 0x30)
}
return val
}
// returns the string read as a bytes slice, whether the value is NULL,
// the number of bytes read and an error, in case the string is longer than
// the input slice
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
// Get length
num, isNull, n := readLengthEncodedInteger(b)
if num < 1 {
return b[n:n], isNull, n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return b[n-int(num) : n : n], false, n, nil
}
return nil, false, n, io.EOF
}
// returns the number of bytes skipped and an error, in case the string is
// longer than the input slice
func skipLengthEncodedString(b []byte) (int, error) {
// Get length
num, _, n := readLengthEncodedInteger(b)
if num < 1 {
return n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return n, nil
}
return n, io.EOF
}
// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// See issue #349
if len(b) == 0 {
return 0, true, 1
}
switch b[0] {
// 251: NULL
case 0xfb:
return 0, true, 1
// 252: value of following 2
case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3
// 253: value of following 3
case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
// 254: value of following 8
case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
}
// 0-250: value of first byte
return uint64(b[0]), false, 1
}
// encodes a uint64 value and appends it to the given bytes slice
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
switch {
case n <= 250:
return append(b, byte(n))
case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8))
case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
}
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}
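// The sketch below is illustrative only (it assumes in-package placement so it
// can call the unexported helpers above). It round-trips a value to show the
// length-encoded integer prefixes: values up to 250 occupy a single byte, while
// 0xfc, 0xfd and 0xfe introduce 2-, 3- and 8-byte little-endian payloads, and
// 0xfb on the wire denotes NULL.
func exampleLengthEncodedInteger() {
	buf := appendLengthEncodedInteger(nil, 300) // [0xfc 0x2c 0x01]
	n, isNull, read := readLengthEncodedInteger(buf)
	fmt.Println(n, isNull, read) // 300 false 3
}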
// reserveBuffer checks cap(buf) and expands the buffer to len(buf) + appendSize.
// If cap(buf) is not large enough, a new buffer is allocated.
func reserveBuffer(buf []byte, appendSize int) []byte {
newSize := len(buf) + appendSize
if cap(buf) < newSize {
// Grow buffer exponentially
newBuf := make([]byte, len(buf)*2+appendSize)
copy(newBuf, buf)
buf = newBuf
}
return buf[:newSize]
}
// escapeBytesBackslash escapes []byte with backslashes (\)
// This escapes the contents of a string (provided as []byte) by adding backslashes before special
// characters, and turning others into specific escape sequences, such as
// turning newlines into \n and null bytes into \0.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
func escapeBytesBackslash(buf, v []byte) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for _, c := range v {
switch c {
case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0'
pos += 2
case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n'
pos += 2
case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r'
pos += 2
case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z'
pos += 2
case '\'':
buf[pos] = '\\'
buf[pos+1] = '\''
pos += 2
case '"':
buf[pos] = '\\'
buf[pos+1] = '"'
pos += 2
case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\'
pos += 2
default:
buf[pos] = c
pos++
}
}
return buf[:pos]
}
// escapeStringBackslash is similar to escapeBytesBackslash but for string.
func escapeStringBackslash(buf []byte, v string) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for i := 0; i < len(v); i++ {
c := v[i]
switch c {
case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0'
pos += 2
case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n'
pos += 2
case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r'
pos += 2
case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z'
pos += 2
case '\'':
buf[pos] = '\\'
buf[pos+1] = '\''
pos += 2
case '"':
buf[pos] = '\\'
buf[pos+1] = '"'
pos += 2
case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\'
pos += 2
default:
buf[pos] = c
pos++
}
}
return buf[:pos]
}
// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
// This escapes the contents of a string by doubling up any apostrophes that
// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
// effect on the server.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
func escapeBytesQuotes(buf, v []byte) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for _, c := range v {
if c == '\'' {
buf[pos] = '\''
buf[pos+1] = '\''
pos += 2
} else {
buf[pos] = c
pos++
}
}
return buf[:pos]
}
// escapeStringQuotes is similar to escapeBytesQuotes but for string.
func escapeStringQuotes(buf []byte, v string) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for i := 0; i < len(v); i++ {
c := v[i]
if c == '\'' {
buf[pos] = '\''
buf[pos+1] = '\''
pos += 2
} else {
buf[pos] = c
pos++
}
}
return buf[:pos]
}
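// The sketch below is illustrative only (it assumes in-package placement so it
// can call the unexported helpers above). It contrasts the two escaping modes:
// backslash escaping is the default, while quote doubling is required when the
// server runs with NO_BACKSLASH_ESCAPES.
func exampleEscaping() {
	backslash := escapeStringBackslash(nil, `it's a "test"`) // it\'s a \"test\"
	quotes := escapeStringQuotes(nil, "it's a test")         // it''s a test
	fmt.Println(string(backslash), string(quotes))
}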
/******************************************************************************
* Sync utils *
******************************************************************************/
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
// for details.
type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}
// atomicBool is a wrapper around uint32 for usage as a boolean value with
// atomic access.
type atomicBool struct {
_noCopy noCopy
value uint32
}
// IsSet returns whether the current boolean value is true
func (ab *atomicBool) IsSet() bool {
return atomic.LoadUint32(&ab.value) > 0
}
// Set sets the value of the bool regardless of the previous value
func (ab *atomicBool) Set(value bool) {
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}
// TrySet sets the value of the bool and returns whether the value changed
func (ab *atomicBool) TrySet(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) == 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}
// atomicError is a wrapper for atomically accessed error values
type atomicError struct {
_noCopy noCopy
value atomic.Value
}
// Set sets the error value regardless of the previous value.
// The value must not be nil
func (ae *atomicError) Set(value error) {
ae.value.Store(value)
}
// Value returns the current error value
func (ae *atomicError) Value() error {
if v := ae.value.Load(); v != nil {
// this will panic if the value doesn't implement the error interface
return v.(error)
}
return nil
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
// TODO: support the use of Named Parameters #561
return nil, errors.New("mysql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
switch sql.IsolationLevel(level) {
case sql.LevelRepeatableRead:
return "REPEATABLE READ", nil
case sql.LevelReadCommitted:
return "READ COMMITTED", nil
case sql.LevelReadUncommitted:
return "READ UNCOMMITTED", nil
case sql.LevelSerializable:
return "SERIALIZABLE", nil
default:
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
}
}

View File

@ -0,0 +1,23 @@
Copyright (c) 2013, Jason Moiron
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,208 @@
package sqlx
import (
"bytes"
"errors"
"reflect"
"strconv"
"strings"
"github.com/jmoiron/sqlx/reflectx"
)
// Bindvar types supported by Rebind, BindMap and BindStruct.
const (
UNKNOWN = iota
QUESTION
DOLLAR
NAMED
)
// BindType returns the bindtype for a given database given a drivername.
func BindType(driverName string) int {
switch driverName {
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
return DOLLAR
case "mysql":
return QUESTION
case "sqlite3":
return QUESTION
case "oci8", "ora", "goracle":
return NAMED
}
return UNKNOWN
}
// FIXME: this should be able to be tolerant of escaped ?'s in queries without
// losing much speed, and should be, in order to avoid confusion.
// Rebind a query from the default bindtype (QUESTION) to the target bindtype.
func Rebind(bindType int, query string) string {
switch bindType {
case QUESTION, UNKNOWN:
return query
}
// Add space enough for 10 params before we have to allocate
rqb := make([]byte, 0, len(query)+10)
var i, j int
for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") {
rqb = append(rqb, query[:i]...)
switch bindType {
case DOLLAR:
rqb = append(rqb, '$')
case NAMED:
rqb = append(rqb, ':', 'a', 'r', 'g')
}
j++
rqb = strconv.AppendInt(rqb, int64(j), 10)
query = query[i+1:]
}
return string(append(rqb, query...))
}
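// Illustrative sketch (not part of the vendored source): rebinding the default
// '?' bindvar for Postgres- and Oracle-style drivers.
func exampleRebind() (string, string) {
	dollar := Rebind(DOLLAR, "SELECT * FROM t WHERE a=? AND b=?") // ... WHERE a=$1 AND b=$2
	named := Rebind(NAMED, "SELECT * FROM t WHERE a=? AND b=?")   // ... WHERE a=:arg1 AND b=:arg2
	return dollar, named
}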
// Experimental implementation of Rebind which uses a bytes.Buffer. The code is
// much simpler and should be more resistant to odd unicode, but it is twice as
// slow. Kept here for benchmarking purposes and to possibly replace Rebind if
// problems arise with its somewhat naive handling of unicode.
func rebindBuff(bindType int, query string) string {
if bindType != DOLLAR {
return query
}
b := make([]byte, 0, len(query))
rqb := bytes.NewBuffer(b)
j := 1
for _, r := range query {
if r == '?' {
rqb.WriteRune('$')
rqb.WriteString(strconv.Itoa(j))
j++
} else {
rqb.WriteRune(r)
}
}
return rqb.String()
}
// In expands slice values in args, returning the modified query string
// and a new arg list that can be executed by a database. The `query` should
// use the `?` bindVar. The return value uses the `?` bindVar.
func In(query string, args ...interface{}) (string, []interface{}, error) {
// argMeta stores reflect.Value and length for slices and
// the value itself for non-slice arguments
type argMeta struct {
v reflect.Value
i interface{}
length int
}
var flatArgsCount int
var anySlices bool
meta := make([]argMeta, len(args))
for i, arg := range args {
v := reflect.ValueOf(arg)
t := reflectx.Deref(v.Type())
// []byte is a driver.Value type so it should not be expanded
if t.Kind() == reflect.Slice && t != reflect.TypeOf([]byte{}) {
meta[i].length = v.Len()
meta[i].v = v
anySlices = true
flatArgsCount += meta[i].length
if meta[i].length == 0 {
return "", nil, errors.New("empty slice passed to 'in' query")
}
} else {
meta[i].i = arg
flatArgsCount++
}
}
// don't do any parsing if there aren't any slices; note that this means
// some errors that we might have caught below will not be returned.
if !anySlices {
return query, args, nil
}
newArgs := make([]interface{}, 0, flatArgsCount)
buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount))
var arg, offset int
for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') {
if arg >= len(meta) {
// if an argument wasn't passed, lets return an error; this is
// not actually how database/sql Exec/Query works, but since we are
// creating an argument list programmatically, we want to be able
// to catch these programmer errors earlier.
return "", nil, errors.New("number of bindVars exceeds arguments")
}
argMeta := meta[arg]
arg++
// not a slice, continue.
// our question mark will either be written before the next expansion
// of a slice or after the loop when writing the rest of the query
if argMeta.length == 0 {
offset = offset + i + 1
newArgs = append(newArgs, argMeta.i)
continue
}
// write everything up to and including our ? character
buf.WriteString(query[:offset+i+1])
for si := 1; si < argMeta.length; si++ {
buf.WriteString(", ?")
}
newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length)
// slice the query and reset the offset. this avoids some bookkeeping for
// the write after the loop
query = query[offset+i+1:]
offset = 0
}
buf.WriteString(query)
if arg < len(meta) {
return "", nil, errors.New("number of bindVars less than number arguments")
}
return buf.String(), newArgs, nil
}
func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} {
switch val := v.Interface().(type) {
case []interface{}:
args = append(args, val...)
case []int:
for i := range val {
args = append(args, val[i])
}
case []string:
for i := range val {
args = append(args, val[i])
}
default:
for si := 0; si < vlen; si++ {
args = append(args, v.Index(si).Interface())
}
}
return args
}
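// Illustrative sketch (not part of the vendored source): expanding a slice
// argument for an IN clause. In rewrites the single '?' into one bindvar per
// element and flattens the args, so the result can be passed through Rebind
// and on to the database.
func exampleIn() (string, []interface{}, error) {
	query, args, err := In("SELECT * FROM users WHERE id IN (?)", []int{1, 2, 3})
	// query == "SELECT * FROM users WHERE id IN (?, ?, ?)"
	// args  == []interface{}{1, 2, 3}
	return query, args, err
}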

View File

@ -0,0 +1,12 @@
// Package sqlx provides general purpose extensions to database/sql.
//
// It is intended to seamlessly wrap database/sql and provide convenience
// methods which are useful in the development of database driven applications.
// None of the underlying database/sql methods are changed. Instead all extended
// behavior is implemented through new methods defined on wrapper types.
//
// Additions include scanning into structs, named query support, rebinding
// queries for different drivers, convenient shorthands for common error handling
// and more.
//
package sqlx

View File

@ -0,0 +1,346 @@
package sqlx
// Named Query Support
//
// * BindMap - bind query bindvars to map/struct args
// * NamedExec, NamedQuery - named query w/ struct or map
// * NamedStmt - a pre-compiled named query which is a prepared statement
//
// Internal Interfaces:
//
// * compileNamedQuery - rebind a named query, returning a query and list of names
// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
//
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
"unicode"
"github.com/jmoiron/sqlx/reflectx"
)
// NamedStmt is a prepared statement that executes named queries. Prepare it
// how you would execute a NamedQuery, but pass in a struct or map when executing.
type NamedStmt struct {
Params []string
QueryString string
Stmt *Stmt
}
// Close closes the named statement.
func (n *NamedStmt) Close() error {
return n.Stmt.Close()
}
// Exec executes a named statement using the struct passed.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return *new(sql.Result), err
}
return n.Stmt.Exec(args...)
}
// Query executes a named statement using the struct argument, returning rows.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return nil, err
}
return n.Stmt.Query(args...)
}
// QueryRow executes a named statement against the database. Because sqlx cannot
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
// returns a *sqlx.Row instead.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRow(arg interface{}) *Row {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return &Row{err: err}
}
return n.Stmt.QueryRowx(args...)
}
// MustExec execs a NamedStmt, panicking on error.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) MustExec(arg interface{}) sql.Result {
res, err := n.Exec(arg)
if err != nil {
panic(err)
}
return res
}
// Queryx using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
r, err := n.Query(arg)
if err != nil {
return nil, err
}
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
}
// QueryRowx using this NamedStmt. Because of limitations with QueryRow, this is
// an alias for QueryRow.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRowx(arg interface{}) *Row {
return n.QueryRow(arg)
}
// Select using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
rows, err := n.Queryx(arg)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// Get using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) Get(dest interface{}, arg interface{}) error {
r := n.QueryRowx(arg)
return r.scanAny(dest, false)
}
// Unsafe creates an unsafe version of the NamedStmt
func (n *NamedStmt) Unsafe() *NamedStmt {
r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString}
r.Stmt.unsafe = true
return r
}
// A union interface of preparer and binder, required to be able to prepare
// named statements (as the bindtype must be determined).
type namedPreparer interface {
Preparer
binder
}
func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) {
bindType := BindType(p.DriverName())
q, args, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return nil, err
}
stmt, err := Preparex(p, q)
if err != nil {
return nil, err
}
return &NamedStmt{
QueryString: q,
Params: args,
Stmt: stmt,
}, nil
}
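// Illustrative sketch (not part of the vendored source; the *DB value, table
// and columns are placeholders): a NamedStmt is prepared once and then executed
// with struct or map arguments.
func exampleNamedStmt(db *DB) error {
	stmt, err := prepareNamed(db, "INSERT INTO person (first, last) VALUES (:first, :last)")
	if err != nil {
		return err
	}
	defer stmt.Close()
	_, err = stmt.Exec(struct {
		First string `db:"first"`
		Last  string `db:"last"`
	}{First: "Jane", Last: "Doe"})
	return err
}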
func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMapArgs(names, maparg)
}
return bindArgs(names, arg, m)
}
// private interface to generate a list of interfaces from a given struct
// type, given a list of names to pull out of the struct. Used by public
// BindStruct interface.
func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names))
// grab the indirected value of arg
v := reflect.ValueOf(arg)
for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; {
v = v.Elem()
}
err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error {
if len(t) == 0 {
return fmt.Errorf("could not find name %s in %#v", names[i], arg)
}
val := reflectx.FieldByIndexesReadOnly(v, t)
arglist = append(arglist, val.Interface())
return nil
})
return arglist, err
}
// like bindArgs, but for maps.
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names))
for _, name := range names {
val, ok := arg[name]
if !ok {
return arglist, fmt.Errorf("could not find name %s in %#v", name, arg)
}
arglist = append(arglist, val)
}
return arglist, nil
}
// bindStruct binds a named parameter query with fields from a struct argument.
// The rules for binding field names to parameter names follow the same
// conventions as for StructScan, including obeying the `db` struct tags.
func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}
arglist, err := bindArgs(names, arg, m)
if err != nil {
return "", []interface{}{}, err
}
return bound, arglist, nil
}
// bindMap binds a named parameter query with a map of arguments.
func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}
arglist, err := bindMapArgs(names, args)
return bound, arglist, err
}
// -- Compilation of Named Queries
// Allow digits and letters in bind params; additionally runes are
// checked against underscores, meaning that bind params can be
// alphanumeric with underscores. Mind the difference between unicode
// digits and numbers, where '5' is a digit but '五' is not.
var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}
// FIXME: this function isn't safe for unicode named params, as a failing test
// can testify. This is not a regression but a failure of the original code
// as well. It should be modified to range over runes in a string rather than
// bytes, even though this is less convenient and slower. Hopefully the
// addition of the prepared NamedStmt (which will only do this once) will make
// up for the slightly slower ad-hoc NamedExec/NamedQuery.
// compile a NamedQuery into an unbound query (using the '?' bindvar) and
// a list of names.
func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
names = make([]string, 0, 10)
rebound := make([]byte, 0, len(qs))
inName := false
last := len(qs) - 1
currentVar := 1
name := make([]byte, 0, 10)
for i, b := range qs {
// a ':' while we're in a name is an error
if b == ':' {
// if this is the second ':' in a '::' escape sequence, append a ':'
if inName && i > 0 && qs[i-1] == ':' {
rebound = append(rebound, ':')
inName = false
continue
} else if inName {
err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
return query, names, err
}
inName = true
name = []byte{}
// if we're in a name, and this is an allowed character, continue
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
// append the byte to the name if we are in a name and not on the last byte
name = append(name, b)
// if we're in a name and it's not an allowed character, the name is done
} else if inName {
inName = false
// if this is the final byte of the string and it is part of the name, then
// make sure to add it to the name
if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
name = append(name, b)
}
// add the string representation to the names list
names = append(names, string(name))
// add a proper bindvar for the bindType
switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
rebound = append(rebound, ':')
rebound = append(rebound, name...)
case QUESTION, UNKNOWN:
rebound = append(rebound, '?')
case DOLLAR:
rebound = append(rebound, '$')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
}
currentVar++
}
// add this byte to string unless it was not part of the name
if i != last {
rebound = append(rebound, b)
} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
rebound = append(rebound, b)
}
} else {
// this is a normal byte and should just go onto the rebound query
rebound = append(rebound, b)
}
}
return string(rebound), names, err
}
// BindNamed binds a struct or a map to a query with named parameters.
// DEPRECATED: use sqlx.Named instead of this; it may be removed in the future.
func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(bindType, query, arg, mapper())
}
// Named takes a query using named parameters and an argument and
// returns a new query with a list of args that can be executed by
// a database. The return value uses the `?` bindvar.
func Named(query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(QUESTION, query, arg, mapper())
}
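// Illustrative sketch (not part of the vendored source): compiling a named
// query against a map argument. Named returns the query rewritten with '?'
// bindvars plus the matching arg list; a struct whose fields carry `db` tags
// works the same way.
func exampleNamed() (string, []interface{}, error) {
	q, args, err := Named(
		"INSERT INTO person (first, last) VALUES (:first, :last)",
		map[string]interface{}{"first": "Jane", "last": "Doe"},
	)
	// q    == "INSERT INTO person (first, last) VALUES (?, ?)"
	// args == []interface{}{"Jane", "Doe"}
	return q, args, err
}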
func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMap(bindType, query, maparg)
}
return bindStruct(bindType, query, arg, m)
}
// NamedQuery binds a named query and then runs Query on the result using the
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
// map[string]interface{} types.
func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.Queryx(q, args...)
}
// NamedExec uses BindStruct to get a query executable by the driver and
// then runs Exec on the result. Returns an error from the binding
// or the query execution itself.
func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.Exec(q, args...)
}

View File

@ -0,0 +1,132 @@
// +build go1.8
package sqlx
import (
"context"
"database/sql"
)
// A union interface of contextPreparer and binder, required to be able to
// prepare named statements with context (as the bindtype must be determined).
type namedPreparerContext interface {
PreparerContext
binder
}
func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) {
bindType := BindType(p.DriverName())
q, args, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return nil, err
}
stmt, err := PreparexContext(ctx, p, q)
if err != nil {
return nil, err
}
return &NamedStmt{
QueryString: q,
Params: args,
Stmt: stmt,
}, nil
}
// ExecContext executes a named statement using the struct passed.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return *new(sql.Result), err
}
return n.Stmt.ExecContext(ctx, args...)
}
// QueryContext executes a named statement using the struct argument, returning rows.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return nil, err
}
return n.Stmt.QueryContext(ctx, args...)
}
// QueryRowContext executes a named statement against the database. Because sqlx cannot
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
// returns a *sqlx.Row instead.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row {
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return &Row{err: err}
}
return n.Stmt.QueryRowxContext(ctx, args...)
}
// MustExecContext execs a NamedStmt, panicking on error.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result {
res, err := n.ExecContext(ctx, arg)
if err != nil {
panic(err)
}
return res
}
// QueryxContext using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) {
r, err := n.QueryContext(ctx, arg)
if err != nil {
return nil, err
}
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
}
// QueryRowxContext using this NamedStmt. Because of limitations with QueryRow, this is
// an alias for QueryRow.
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row {
return n.QueryRowContext(ctx, arg)
}
// SelectContext using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error {
rows, err := n.QueryxContext(ctx, arg)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// GetContext using this NamedStmt
// Any named placeholder parameters are replaced with fields from arg.
func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error {
r := n.QueryRowxContext(ctx, arg)
return r.scanAny(dest, false)
}
// NamedQueryContext binds a named query and then runs Query on the result using the
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
// map[string]interface{} types.
func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.QueryxContext(ctx, q, args...)
}
// NamedExecContext uses BindStruct to get a query executable by the driver and
// then runs Exec on the result. Returns an error from the binding
// or the query execution itself.
func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) {
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
return e.ExecContext(ctx, q, args...)
}

View File

@ -0,0 +1,441 @@
// Package reflectx implements extensions to the standard reflect lib suitable
// for implementing marshalling and unmarshalling packages. The main Mapper type
// allows for Go-compatible named attribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to
// customize field names.
//
package reflectx
import (
"reflect"
"runtime"
"strings"
"sync"
)
// A FieldInfo is metadata for a struct field.
type FieldInfo struct {
Index []int
Path string
Field reflect.StructField
Zero reflect.Value
Name string
Options map[string]string
Embedded bool
Children []*FieldInfo
Parent *FieldInfo
}
// A StructMap is an index of field metadata for a struct.
type StructMap struct {
Tree *FieldInfo
Index []*FieldInfo
Paths map[string]*FieldInfo
Names map[string]*FieldInfo
}
// GetByPath returns a *FieldInfo for a given string path.
func (f StructMap) GetByPath(path string) *FieldInfo {
return f.Paths[path]
}
// GetByTraversal returns a *FieldInfo for a given integer path. It is
// analogous to reflect.FieldByIndex, but using the cached traversal
// rather than re-executing the reflect machinery each time.
func (f StructMap) GetByTraversal(index []int) *FieldInfo {
if len(index) == 0 {
return nil
}
tree := f.Tree
for _, i := range index {
if i >= len(tree.Children) || tree.Children[i] == nil {
return nil
}
tree = tree.Children[i]
}
return tree
}
// Mapper is a general purpose mapper of names to struct fields. A Mapper
// behaves like most marshallers in the standard library, obeying a field tag
// for name mapping but also providing a basic transform function.
type Mapper struct {
cache map[reflect.Type]*StructMap
tagName string
tagMapFunc func(string) string
mapFunc func(string) string
mutex sync.Mutex
}
// NewMapper returns a new mapper using the tagName as its struct field tag.
// If tagName is the empty string, it is ignored.
func NewMapper(tagName string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
}
}
// NewMapperTagFunc returns a new mapper which contains a mapper for field names
// AND a mapper for tag values. This is useful for tags like json which can
// have values like "name,omitempty".
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
mapFunc: mapFunc,
tagMapFunc: tagMapFunc,
}
}
// NewMapperFunc returns a new mapper which optionally obeys a field tag and
// a struct field name mapper func given by f. Tags will take precedence, but
// for any other field, the mapped name will be f(field.Name)
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
mapFunc: f,
}
}
// TypeMap returns a mapping of field strings to int slices representing
// the traversal down the struct to reach the field.
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
m.mutex.Lock()
mapping, ok := m.cache[t]
if !ok {
mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
m.cache[t] = mapping
}
m.mutex.Unlock()
return mapping
}
// FieldMap returns the mapper's mapping of field names to reflect values. Panics
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
r := map[string]reflect.Value{}
tm := m.TypeMap(v.Type())
for tagName, fi := range tm.Names {
r[tagName] = FieldByIndexes(v, fi.Index)
}
return r
}
// FieldByName returns a field by its mapped name as a reflect.Value.
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
// Returns zero Value if the name is not found.
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
tm := m.TypeMap(v.Type())
fi, ok := tm.Names[name]
if !ok {
return v
}
return FieldByIndexes(v, fi.Index)
}
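// Illustrative sketch (not part of the vendored source): a Mapper that obeys
// `db` tags and falls back to lower-cased field names, used to look a field up
// by its mapped name.
type exampleUser struct {
	ID   int    `db:"id"`
	Name string // mapped to "name" by the map func
}

func exampleFieldByName() string {
	m := NewMapperFunc("db", strings.ToLower)
	u := exampleUser{ID: 1, Name: "Jane"}
	return m.FieldByName(reflect.ValueOf(u), "name").String() // "Jane"
}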
// FieldsByName returns a slice of values corresponding to the slice of names
// for the value. Panics if v's Kind is not Struct or v is not Indirectable
// to a struct Kind. Returns zero Value for each name not found.
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
tm := m.TypeMap(v.Type())
vals := make([]reflect.Value, 0, len(names))
for _, name := range names {
fi, ok := tm.Names[name]
if !ok {
vals = append(vals, *new(reflect.Value))
} else {
vals = append(vals, FieldByIndexes(v, fi.Index))
}
}
return vals
}
// TraversalsByName returns a slice of int slices which represent the struct
// traversals for each mapped name. Panics if t is not a struct or Indirectable
// to a struct. Returns empty int slice for each name not found.
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
r := make([][]int, 0, len(names))
m.TraversalsByNameFunc(t, names, func(_ int, i []int) error {
if i == nil {
r = append(r, []int{})
} else {
r = append(r, i)
}
return nil
})
return r
}
// TraversalsByNameFunc traverses the mapped names and calls fn with the index of
// each name and the struct traversal represented by that name. Panics if t is not
// a struct or Indirectable to a struct. Returns the first error returned by fn or nil.
func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error {
t = Deref(t)
mustBe(t, reflect.Struct)
tm := m.TypeMap(t)
for i, name := range names {
fi, ok := tm.Names[name]
if !ok {
if err := fn(i, nil); err != nil {
return err
}
} else {
if err := fn(i, fi.Index); err != nil {
return err
}
}
}
return nil
}
// FieldByIndexes returns a value for the field given by the struct traversal
// for the given value.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
// if this is a pointer and it's nil, allocate a new value and set it
if v.Kind() == reflect.Ptr && v.IsNil() {
alloc := reflect.New(Deref(v.Type()))
v.Set(alloc)
}
if v.Kind() == reflect.Map && v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}
}
return v
}
// FieldByIndexesReadOnly returns a value for a particular struct traversal,
// but is not concerned with allocating nil pointers because the value is
// going to be used for reading and not setting.
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
}
return v
}
// Deref is Indirect for reflect.Types
func Deref(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
// -- helpers & utilities --
type kinder interface {
Kind() reflect.Kind
}
// mustBe checks a value against a kind, panicking with a reflect.ValueError
// if the kind isn't that which is required.
func mustBe(v kinder, expected reflect.Kind) {
if k := v.Kind(); k != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: k})
}
}
// methodName returns the caller of the function calling methodName
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}
type typeQueue struct {
t reflect.Type
fi *FieldInfo
pp string // Parent path
}
// A copying append that creates a new slice each time.
func apnd(is []int, i int) []int {
x := make([]int, len(is)+1)
for p, n := range is {
x[p] = n
}
x[len(x)-1] = i
return x
}
type mapf func(string) string
// parseName parses the tag and the target name for the given field using
// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the
// field's name to a target name, and tagMapFunc for mapping the tag to
// a target name.
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) {
// first, set the fieldName to the field's name
fieldName = field.Name
// if a mapFunc is set, use that to override the fieldName
if mapFunc != nil {
fieldName = mapFunc(fieldName)
}
// if there's no tag to look for, return the field name
if tagName == "" {
return "", fieldName
}
// if this tag is not set using the normal convention in the tag,
// then return the fieldName. This check is done because, according
// to the reflect documentation:
// If the tag does not have the conventional format,
// the value returned by Get is unspecified.
// which doesn't sound great.
if !strings.Contains(string(field.Tag), tagName+":") {
return "", fieldName
}
// at this point we're fairly sure that we have a tag, so lets pull it out
tag = field.Tag.Get(tagName)
// if we have a mapper function, call it on the whole tag
// XXX: this is a change from the old version, which pulled out the name
// before the tagMapFunc could be run, but I think this is the right way
if tagMapFunc != nil {
tag = tagMapFunc(tag)
}
// finally, split the options from the name
parts := strings.Split(tag, ",")
fieldName = parts[0]
return tag, fieldName
}
// parseOptions parses options out of a tag string, skipping the name
func parseOptions(tag string) map[string]string {
parts := strings.Split(tag, ",")
options := make(map[string]string, len(parts))
if len(parts) > 1 {
for _, opt := range parts[1:] {
// short circuit potentially expensive split op
if strings.Contains(opt, "=") {
kv := strings.Split(opt, "=")
options[kv[0]] = kv[1]
continue
}
options[opt] = ""
}
}
return options
}
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields.
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap {
m := []*FieldInfo{}
root := &FieldInfo{}
queue := []typeQueue{}
queue = append(queue, typeQueue{Deref(t), root, ""})
QueueLoop:
for len(queue) != 0 {
// pop the first item off of the queue
tq := queue[0]
queue = queue[1:]
// ignore recursive field
for p := tq.fi.Parent; p != nil; p = p.Parent {
if tq.fi.Field.Type == p.Field.Type {
continue QueueLoop
}
}
nChildren := 0
if tq.t.Kind() == reflect.Struct {
nChildren = tq.t.NumField()
}
tq.fi.Children = make([]*FieldInfo, nChildren)
// iterate through all of its fields
for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
f := tq.t.Field(fieldPos)
// parse the tag and the target name using the mapping options for this field
tag, name := parseName(f, tagName, mapFunc, tagMapFunc)
// if the name is "-", disabled via a tag, skip it
if name == "-" {
continue
}
fi := FieldInfo{
Field: f,
Name: name,
Zero: reflect.New(f.Type).Elem(),
Options: parseOptions(tag),
}
// if the path is empty this path is just the name
if tq.pp == "" {
fi.Path = fi.Name
} else {
fi.Path = tq.pp + "." + fi.Name
}
// skip unexported fields
if len(f.PkgPath) != 0 && !f.Anonymous {
continue
}
// bfs search of anonymous embedded structs
if f.Anonymous {
pp := tq.pp
if tag != "" {
pp = fi.Path
}
fi.Embedded = true
fi.Index = apnd(tq.fi.Index, fieldPos)
nChildren := 0
ft := Deref(f.Type)
if ft.Kind() == reflect.Struct {
nChildren = ft.NumField()
}
fi.Children = make([]*FieldInfo, nChildren)
queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
}
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Parent = tq.fi
tq.fi.Children[fieldPos] = &fi
m = append(m, &fi)
}
}
flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
for _, fi := range flds.Index {
flds.Paths[fi.Path] = fi
if fi.Name != "" && !fi.Embedded {
flds.Names[fi.Path] = fi
}
}
return flds
}

File diff suppressed because it is too large

View File

@ -0,0 +1,348 @@
// +build go1.8
package sqlx
import (
"context"
"database/sql"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
)
// ConnectContext to a database and verify with a ping.
func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) {
db, err := Open(driverName, dataSourceName)
if err != nil {
return db, err
}
err = db.PingContext(ctx)
return db, err
}
// QueryerContext is an interface used by GetContext and SelectContext
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row
}
// PreparerContext is an interface used by PreparexContext.
type PreparerContext interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// ExecerContext is an interface used by MustExecContext and LoadFileContext
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
// ExtContext is a union interface which can bind, query, and exec, with Context
// used by NamedQueryContext and NamedExecContext.
type ExtContext interface {
binder
QueryerContext
ExecerContext
}
// SelectContext executes a query using the provided Queryer, and StructScans
// each row into dest, which must be a slice. If the slice elements are
// scannable, then the result set must have only one column. Otherwise,
// StructScan is used. The *sql.Rows are closed automatically.
// Any placeholder parameters are replaced with supplied args.
func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
rows, err := q.QueryxContext(ctx, query, args...)
if err != nil {
return err
}
// if something happens here, we want to make sure the rows are Closed
defer rows.Close()
return scanAll(rows, dest, false)
}
// PreparexContext prepares a statement.
//
// The provided context is used for the preparation of the statement, not for
// the execution of the statement.
func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) {
s, err := p.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err
}
// GetContext does a QueryRow using the provided Queryer, and scans the
// resulting row to dest. If dest is scannable, the result must only have one
// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like
// row.Scan would. Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
r := q.QueryRowxContext(ctx, query, args...)
return r.scanAny(dest, false)
}
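// Illustrative sketch (not part of the vendored source; the table and columns
// are placeholders): typical use of the context-aware helpers above through an
// *sqlx.DB.
func exampleSelectGetContext(ctx context.Context, db *DB) error {
	var names []string
	if err := db.SelectContext(ctx, &names, "SELECT name FROM person WHERE active = ?", true); err != nil {
		return err
	}
	var name string
	// GetContext scans a single row and returns sql.ErrNoRows when nothing matches.
	return db.GetContext(ctx, &name, "SELECT name FROM person WHERE id = ?", 1)
}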
// LoadFileContext exec's every statement in a file (as a single call to Exec).
// LoadFileContext may return a nil *sql.Result if errors are encountered
// locating or reading the file at path. LoadFile reads the entire file into
// memory, so it is not suitable for loading large data dumps, but can be useful
// for initializing schemas or loading indexes.
//
// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3
// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting
// this by requiring something with DriverName() and then attempting to split the
// queries will be difficult to get right, and its current driver-specific behavior
// is deemed at least not complex in its incorrectness.
func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) {
realpath, err := filepath.Abs(path)
if err != nil {
return nil, err
}
contents, err := ioutil.ReadFile(realpath)
if err != nil {
return nil, err
}
res, err := e.ExecContext(ctx, string(contents))
return &res, err
}
// MustExecContext execs the query using e and panics if there was an error.
// Any placeholder parameters are replaced with supplied args.
func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result {
res, err := e.ExecContext(ctx, query, args...)
if err != nil {
panic(err)
}
return res
}
// PrepareNamedContext returns an sqlx.NamedStmt
func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) {
return prepareNamedContext(ctx, db, query)
}
// NamedQueryContext using this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) {
return NamedQueryContext(ctx, db, query, arg)
}
// NamedExecContext using this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
return NamedExecContext(ctx, db, query, arg)
}
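// Example of a named exec (editor's sketch; db and ctx assumed as above):
// named placeholders are bound from struct fields or map keys, so a map
// avoids any assumptions about struct tags.
//
//    _, err := db.NamedExecContext(ctx,
//        "INSERT INTO person (first_name, last_name) VALUES (:first, :last)",
//        map[string]interface{}{"first": "Ada", "last": "Lovelace"})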
// SelectContext using this DB.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return SelectContext(ctx, db, dest, query, args...)
}
// GetContext using this DB.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return GetContext(ctx, db, dest, query, args...)
}
// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt.
//
// The provided context is used for the preparation of the statement, not for
// the execution of the statement.
func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
return PreparexContext(ctx, db, query)
}
// QueryxContext queries the database and returns an *sqlx.Rows.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
r, err := db.DB.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
}
// QueryRowxContext queries the database and returns an *sqlx.Row.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := db.DB.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
}
// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead
// of an *sql.Tx.
//
// The provided context is used until the transaction is committed or rolled
// back. If the context is canceled, the sql package will roll back the
// transaction. Tx.Commit will return an error if the context provided to
// MustBeginTx is canceled.
func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx {
tx, err := db.BeginTxx(ctx, opts)
if err != nil {
panic(err)
}
return tx
}
// MustExecContext (panic) runs MustExec using this database.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
return MustExecContext(ctx, db, query, args...)
}
// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an
// *sql.Tx.
//
// The provided context is used until the transaction is committed or rolled
// back. If the context is canceled, the sql package will roll back the
// transaction. Tx.Commit will return an error if the context provided to
// BeginTxx is canceled.
func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
tx, err := db.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
}
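// Example transaction flow with BeginTxx (editor's sketch; db and ctx assumed
// as above): the returned *sqlx.Tx exposes the same extensions as *sqlx.DB,
// plus the embedded *sql.Tx methods such as ExecContext, Commit and Rollback.
//
//    tx, err := db.BeginTxx(ctx, nil)
//    if err != nil {
//        return err
//    }
//    if _, err := tx.ExecContext(ctx, "DELETE FROM person WHERE id = ?", 1); err != nil {
//        tx.Rollback()
//        return err
//    }
//    return tx.Commit()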
// StmtxContext returns a version of the prepared statement which runs within a
// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt.
func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt {
var s *sql.Stmt
switch v := stmt.(type) {
case Stmt:
s = v.Stmt
case *Stmt:
s = v.Stmt
case sql.Stmt:
s = &v
case *sql.Stmt:
s = v
default:
panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type()))
}
return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper}
}
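// Example use of StmtxContext (editor's sketch; db, tx, and ctx assumed as
// above): a statement prepared on the DB can be rebound to a transaction for
// the duration of that transaction.
//
//    stmt, err := db.PreparexContext(ctx, "SELECT first_name FROM person WHERE id = ?")
//    if err != nil {
//        return err
//    }
//    txStmt := tx.StmtxContext(ctx, stmt)
//    row := txStmt.QueryRowxContext(ctx, 1)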
// NamedStmtContext returns a version of the prepared statement which runs
// within a transaction.
func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt {
return &NamedStmt{
QueryString: stmt.QueryString,
Params: stmt.Params,
Stmt: tx.StmtxContext(ctx, stmt.Stmt),
}
}
// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt.
//
// The provided context is used for the preparation of the statement, not for
// the execution of the statement.
func (tx *Tx) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
return PreparexContext(ctx, tx, query)
}
// PrepareNamedContext returns an sqlx.NamedStmt
func (tx *Tx) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) {
return prepareNamedContext(ctx, tx, query)
}
// MustExecContext runs MustExecContext within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
return MustExecContext(ctx, tx, query, args...)
}
// QueryxContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
r, err := tx.Tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
}
// SelectContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return SelectContext(ctx, tx, dest, query, args...)
}
// GetContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return GetContext(ctx, tx, dest, query, args...)
}
// QueryRowxContext within a transaction and context.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := tx.Tx.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
}
// NamedExecContext using this Tx.
// Any named placeholder parameters are replaced with fields from arg.
func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
return NamedExecContext(ctx, tx, query, arg)
}
// SelectContext using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error {
return SelectContext(ctx, &qStmt{s}, dest, "", args...)
}
// GetContext using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error {
return GetContext(ctx, &qStmt{s}, dest, "", args...)
}
// MustExecContext (panic) using this statement. Note that the query portion of
// the error output will be blank, as Stmt does not expose its query.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result {
return MustExecContext(ctx, &qStmt{s}, "", args...)
}
// QueryRowxContext using this statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row {
qs := &qStmt{s}
return qs.QueryRowxContext(ctx, "", args...)
}
// QueryxContext using this statement.
// Any placeholder parameters are replaced with supplied args.
func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) {
qs := &qStmt{s}
return qs.QueryxContext(ctx, "", args...)
}
func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return q.Stmt.QueryContext(ctx, args...)
}
func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
r, err := q.Stmt.QueryContext(ctx, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
}
func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := q.Stmt.QueryContext(ctx, args...)
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}
}
func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return q.Stmt.ExecContext(ctx, args...)
}

View File

@ -0,0 +1,191 @@
All files in this repository are licensed as follows. If you contribute
to this repository, it is assumed that you license your contribution
under the same license unless you state otherwise.
All files Copyright (C) 2015 Canonical Ltd. unless otherwise specified in the file.
This software is licensed under the LGPLv3, included below.
As a special exception to the GNU Lesser General Public License version 3
("LGPL3"), the copyright holders of this Library give you permission to
convey to a third party a Combined Work that links statically or dynamically
to this Library without providing any Minimal Corresponding Source or
Minimal Application Code as set out in 4d or providing the installation
information set out in section 4e, provided that you comply with the other
provisions of LGPL3 and provided that you meet, for the Application the
terms and conditions of the license(s) which apply to the Application.
Except as stated in this special exception, the provisions of LGPL3 will
continue to comply in full to this Library. If you modify this Library, you
may apply this exception to your version of this Library, but you are not
obliged to do so. If you do not wish to do so, delete this exception
statement from your version. This exception does not (and cannot) modify any
license terms which apply to the Application, with which you must still
comply.
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.

View File

@ -0,0 +1,81 @@
// Copyright 2013, 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
/*
[godoc-link-here]
The juju/errors package provides an easy way to annotate errors without losing
the original error context.
The exported `New` and `Errorf` functions are designed to replace the
`errors.New` and `fmt.Errorf` functions respectively. The same underlying
error is there, but the package also records the location at which the error
was created.
A primary use case for this library is to add extra context any time an
error is returned from a function.
if err := SomeFunc(); err != nil {
return err
}
This instead becomes:
if err := SomeFunc(); err != nil {
return errors.Trace(err)
}
which just records the file and line number of the Trace call, or
if err := SomeFunc(); err != nil {
return errors.Annotate(err, "more context")
}
which also adds an annotation to the error.
When you want to check to see if an error is of a particular type, a helper
function is normally exported by the package that returned the error, like the
`os` package does. The underlying cause of the error is available using the
`Cause` function.
os.IsNotExist(errors.Cause(err))
The result of the `Error()` call on an annotated error is the annotations joined
with colons, then the result of the `Error()` method for the underlying error
that was the cause.
err := errors.Errorf("original")
err = errors.Annotatef(err, "context")
err = errors.Annotatef(err, "more context")
err.Error() -> "more context: context: original"
Obviously recording the file, line and functions is not very useful if you
cannot get them back out again.
errors.ErrorStack(err)
will return something like:
first error
github.com/juju/errors/annotation_test.go:193:
github.com/juju/errors/annotation_test.go:194: annotation
github.com/juju/errors/annotation_test.go:195:
github.com/juju/errors/annotation_test.go:196: more context
github.com/juju/errors/annotation_test.go:197:
The first error was generated by an external system, so there was no location
associated. The second, fourth, and last lines were generated with Trace calls,
and the other two through Annotate.
Sometimes when responding to an error you want to return a more specific error
for the situation.
if err := FindField(field); err != nil {
return errors.Wrap(err, errors.NotFoundf(field))
}
This returns an error where the complete error stack is still available, and
`errors.Cause()` will return the `NotFound` error.
*/
package errors

View File

@ -0,0 +1,172 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"fmt"
"reflect"
"runtime"
)
// Err holds a description of an error along with information about
// where the error was created.
//
// It may be embedded in custom error types to add extra information that
// this errors package can understand.
type Err struct {
// message holds an annotation of the error.
message string
// cause holds the cause of the error as returned
// by the Cause method.
cause error
// previous holds the previous error in the error stack, if any.
previous error
// file and line hold the source code location where the error was
// created.
file string
line int
}
// NewErr is used to return an Err for the purpose of embedding in other
// structures. The location is not specified, and needs to be set with a call
// to SetLocation.
//
// For example:
// type FooError struct {
// errors.Err
// code int
// }
//
// func NewFooError(code int) error {
// err := &FooError{errors.NewErr("foo"), code}
// err.SetLocation(1)
// return err
// }
func NewErr(format string, args ...interface{}) Err {
return Err{
message: fmt.Sprintf(format, args...),
}
}
// NewErrWithCause is used to return an Err whose cause is the given error, for the purpose of embedding in other
// structures. The location is not specified, and needs to be set with a call
// to SetLocation.
//
// For example:
// type FooError struct {
// errors.Err
// code int
// }
//
// func (e *FooError) Annotate(format string, args ...interface{}) error {
// err := &FooError{errors.NewErrWithCause(e.Err, format, args...), e.code}
// err.SetLocation(1)
// return err
// })
func NewErrWithCause(other error, format string, args ...interface{}) Err {
return Err{
message: fmt.Sprintf(format, args...),
cause: Cause(other),
previous: other,
}
}
// Location is the file and line of where the error was most recently
// created or annotated.
func (e *Err) Location() (filename string, line int) {
return e.file, e.line
}
// Underlying returns the previous error in the error stack, if any. A client
// should not ever really call this method. It is used to build the error
// stack and should not be introspected by client calls. Or more
// specifically, clients should not depend on anything but the `Cause` of an
// error.
func (e *Err) Underlying() error {
return e.previous
}
// The Cause of an error is the most recent error in the error stack that
// meets one of these criteria: the original error that was raised; the new
// error that was passed into the Wrap function; the most recently masked
// error; or nil if the error itself is considered the Cause. Normally this
// method is not invoked directly, but instead through the Cause stand alone
// function.
func (e *Err) Cause() error {
return e.cause
}
// Message returns the message stored with the most recent location. This is
// the empty string if the most recent call was Trace, or the message stored
// with Annotate or Mask.
func (e *Err) Message() string {
return e.message
}
// Error implements error.Error.
func (e *Err) Error() string {
// We want to walk up the stack of errors showing the annotations
// as long as the cause is the same.
err := e.previous
if !sameError(Cause(err), e.cause) && e.cause != nil {
err = e.cause
}
switch {
case err == nil:
return e.message
case e.message == "":
return err.Error()
}
return fmt.Sprintf("%s: %v", e.message, err)
}
// Format implements fmt.Formatter
// When printing errors with %+v it also prints the stack trace.
// %#v unsurprisingly will print the real underlying type.
func (e *Err) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
fmt.Fprintf(s, "%s", ErrorStack(e))
return
case s.Flag('#'):
// avoid infinite recursion by wrapping e into a type
// that doesn't implement Formatter.
fmt.Fprintf(s, "%#v", (*unformatter)(e))
return
}
fallthrough
case 's':
fmt.Fprintf(s, "%s", e.Error())
}
}
// helper for Format
type unformatter Err
func (unformatter) Format() { /* break the fmt.Formatter interface */ }
// SetLocation records the source location of the error at callDepth stack
// frames above the call.
func (e *Err) SetLocation(callDepth int) {
_, file, line, _ := runtime.Caller(callDepth + 1)
e.file = trimGoPath(file)
e.line = line
}
// StackTrace returns one string for each location recorded in the stack of
// errors. The first value is the originating error, with a line for each
// other annotation or tracing of the error.
func (e *Err) StackTrace() []string {
return errorStack(e)
}
// Ideally we'd have a way to check identity, but deep equals will do.
func sameError(e1, e2 error) bool {
return reflect.DeepEqual(e1, e2)
}

View File

@ -0,0 +1,309 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"fmt"
)
// wrap is a helper to construct an *wrapper.
func wrap(err error, format, suffix string, args ...interface{}) Err {
newErr := Err{
message: fmt.Sprintf(format+suffix, args...),
previous: err,
}
newErr.SetLocation(2)
return newErr
}
// notFound represents an error when something has not been found.
type notFound struct {
Err
}
// NotFoundf returns an error which satisfies IsNotFound().
func NotFoundf(format string, args ...interface{}) error {
return &notFound{wrap(nil, format, " not found", args...)}
}
// NewNotFound returns an error which wraps err that satisfies
// IsNotFound().
func NewNotFound(err error, msg string) error {
return &notFound{wrap(err, msg, "")}
}
// IsNotFound reports whether err was created with NotFoundf() or
// NewNotFound().
func IsNotFound(err error) bool {
err = Cause(err)
_, ok := err.(*notFound)
return ok
}
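// Example (editor's sketch; the variable name is assumed): the constructor
// plus predicate pattern below is repeated for every error kind in this file.
//
//    err := errors.NotFoundf("user %q", name)
//    // err.Error() == `user "bob" not found` when name is "bob"
//    if errors.IsNotFound(err) {
//        // handle the missing user
//    }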
// userNotFound represents an error when a nonexistent user is looked up.
type userNotFound struct {
Err
}
// UserNotFoundf returns an error which satisfies IsUserNotFound().
func UserNotFoundf(format string, args ...interface{}) error {
return &userNotFound{wrap(nil, format, " user not found", args...)}
}
// NewUserNotFound returns an error which wraps err and satisfies
// IsUserNotFound().
func NewUserNotFound(err error, msg string) error {
return &userNotFound{wrap(err, msg, "")}
}
// IsUserNotFound reports whether err was created with UserNotFoundf() or
// NewUserNotFound().
func IsUserNotFound(err error) bool {
err = Cause(err)
_, ok := err.(*userNotFound)
return ok
}
// unauthorized represents an error when an operation is unauthorized.
type unauthorized struct {
Err
}
// Unauthorizedf returns an error which satisfies IsUnauthorized().
func Unauthorizedf(format string, args ...interface{}) error {
return &unauthorized{wrap(nil, format, "", args...)}
}
// NewUnauthorized returns an error which wraps err and satisfies
// IsUnauthorized().
func NewUnauthorized(err error, msg string) error {
return &unauthorized{wrap(err, msg, "")}
}
// IsUnauthorized reports whether err was created with Unauthorizedf() or
// NewUnauthorized().
func IsUnauthorized(err error) bool {
err = Cause(err)
_, ok := err.(*unauthorized)
return ok
}
// notImplemented represents an error when something is not
// implemented.
type notImplemented struct {
Err
}
// NotImplementedf returns an error which satisfies IsNotImplemented().
func NotImplementedf(format string, args ...interface{}) error {
return &notImplemented{wrap(nil, format, " not implemented", args...)}
}
// NewNotImplemented returns an error which wraps err and satisfies
// IsNotImplemented().
func NewNotImplemented(err error, msg string) error {
return &notImplemented{wrap(err, msg, "")}
}
// IsNotImplemented reports whether err was created with
// NotImplementedf() or NewNotImplemented().
func IsNotImplemented(err error) bool {
err = Cause(err)
_, ok := err.(*notImplemented)
return ok
}
// alreadyExists represents an error when something already exists.
type alreadyExists struct {
Err
}
// AlreadyExistsf returns an error which satisfies IsAlreadyExists().
func AlreadyExistsf(format string, args ...interface{}) error {
return &alreadyExists{wrap(nil, format, " already exists", args...)}
}
// NewAlreadyExists returns an error which wraps err and satisfies
// IsAlreadyExists().
func NewAlreadyExists(err error, msg string) error {
return &alreadyExists{wrap(err, msg, "")}
}
// IsAlreadyExists reports whether the error was created with
// AlreadyExistsf() or NewAlreadyExists().
func IsAlreadyExists(err error) bool {
err = Cause(err)
_, ok := err.(*alreadyExists)
return ok
}
// notSupported represents an error when something is not supported.
type notSupported struct {
Err
}
// NotSupportedf returns an error which satisfies IsNotSupported().
func NotSupportedf(format string, args ...interface{}) error {
return &notSupported{wrap(nil, format, " not supported", args...)}
}
// NewNotSupported returns an error which wraps err and satisfies
// IsNotSupported().
func NewNotSupported(err error, msg string) error {
return &notSupported{wrap(err, msg, "")}
}
// IsNotSupported reports whether the error was created with
// NotSupportedf() or NewNotSupported().
func IsNotSupported(err error) bool {
err = Cause(err)
_, ok := err.(*notSupported)
return ok
}
// notValid represents an error when something is not valid.
type notValid struct {
Err
}
// NotValidf returns an error which satisfies IsNotValid().
func NotValidf(format string, args ...interface{}) error {
return &notValid{wrap(nil, format, " not valid", args...)}
}
// NewNotValid returns an error which wraps err and satisfies IsNotValid().
func NewNotValid(err error, msg string) error {
return &notValid{wrap(err, msg, "")}
}
// IsNotValid reports whether the error was created with NotValidf() or
// NewNotValid().
func IsNotValid(err error) bool {
err = Cause(err)
_, ok := err.(*notValid)
return ok
}
// notProvisioned represents an error when something is not yet provisioned.
type notProvisioned struct {
Err
}
// NotProvisionedf returns an error which satisfies IsNotProvisioned().
func NotProvisionedf(format string, args ...interface{}) error {
return &notProvisioned{wrap(nil, format, " not provisioned", args...)}
}
// NewNotProvisioned returns an error which wraps err that satisfies
// IsNotProvisioned().
func NewNotProvisioned(err error, msg string) error {
return &notProvisioned{wrap(err, msg, "")}
}
// IsNotProvisioned reports whether err was created with NotProvisionedf() or
// NewNotProvisioned().
func IsNotProvisioned(err error) bool {
err = Cause(err)
_, ok := err.(*notProvisioned)
return ok
}
// notAssigned represents an error when something is not yet assigned to
// something else.
type notAssigned struct {
Err
}
// NotAssignedf returns an error which satisfies IsNotAssigned().
func NotAssignedf(format string, args ...interface{}) error {
return &notAssigned{wrap(nil, format, " not assigned", args...)}
}
// NewNotAssigned returns an error which wraps err that satisfies
// IsNotAssigned().
func NewNotAssigned(err error, msg string) error {
return &notAssigned{wrap(err, msg, "")}
}
// IsNotAssigned reports whether err was created with NotAssignedf() or
// NewNotAssigned().
func IsNotAssigned(err error) bool {
err = Cause(err)
_, ok := err.(*notAssigned)
return ok
}
// badRequest represents an error when a request has bad parameters.
type badRequest struct {
Err
}
// BadRequestf returns an error which satisfies IsBadRequest().
func BadRequestf(format string, args ...interface{}) error {
return &badRequest{wrap(nil, format, "", args...)}
}
// NewBadRequest returns an error which wraps err that satisfies
// IsBadRequest().
func NewBadRequest(err error, msg string) error {
return &badRequest{wrap(err, msg, "")}
}
// IsBadRequest reports whether err was created with BadRequestf() or
// NewBadRequest().
func IsBadRequest(err error) bool {
err = Cause(err)
_, ok := err.(*badRequest)
return ok
}
// methodNotAllowed represents an error when an HTTP request
// is made with an inappropriate method.
type methodNotAllowed struct {
Err
}
// MethodNotAllowedf returns an error which satisfies IsMethodNotAllowed().
func MethodNotAllowedf(format string, args ...interface{}) error {
return &methodNotAllowed{wrap(nil, format, "", args...)}
}
// NewMethodNotAllowed returns an error which wraps err that satisfies
// IsMethodNotAllowed().
func NewMethodNotAllowed(err error, msg string) error {
return &methodNotAllowed{wrap(err, msg, "")}
}
// IsMethodNotAllowed reports whether err was created with MethodNotAllowedf() or
// NewMethodNotAllowed().
func IsMethodNotAllowed(err error) bool {
err = Cause(err)
_, ok := err.(*methodNotAllowed)
return ok
}
// forbidden represents an error when a request cannot be completed because of
// missing privileges
type forbidden struct {
Err
}
// Forbiddenf returns an error which satisfies IsForbidden().
func Forbiddenf(format string, args ...interface{}) error {
return &forbidden{wrap(nil, format, "", args...)}
}
// NewForbidden returns an error which wraps err that satisfies
// IsForbidden().
func NewForbidden(err error, msg string) error {
return &forbidden{wrap(err, msg, "")}
}
// IsForbidden reports whether err was created with Forbiddenf() or
// NewForbidden().
func IsForbidden(err error) bool {
err = Cause(err)
_, ok := err.(*forbidden)
return ok
}

View File

@ -0,0 +1,330 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"fmt"
"strings"
)
// New is a drop-in replacement for the standard library errors.New function; it records
// the location at which the error is created.
//
// For example:
// return errors.New("validation failed")
//
func New(message string) error {
err := &Err{message: message}
err.SetLocation(1)
return err
}
// Errorf creates a new annotated error and records the location at which the
// error is created. It is a drop-in replacement for fmt.Errorf.
//
// For example:
// return errors.Errorf("validation failed: %s", message)
//
func Errorf(format string, args ...interface{}) error {
err := &Err{message: fmt.Sprintf(format, args...)}
err.SetLocation(1)
return err
}
// Trace adds the location of the Trace call to the stack. The Cause of the
// resulting error is the same as the error parameter. If the other error is
// nil, the result will be nil.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Trace(err)
// }
//
func Trace(other error) error {
if other == nil {
return nil
}
err := &Err{previous: other, cause: Cause(other)}
err.SetLocation(1)
return err
}
// Annotate is used to add extra context to an existing error. The location of
// the Annotate call is recorded with the annotations. The file, line and
// function are also recorded.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Annotate(err, "failed to frombulate")
// }
//
func Annotate(other error, message string) error {
if other == nil {
return nil
}
err := &Err{
previous: other,
cause: Cause(other),
message: message,
}
err.SetLocation(1)
return err
}
// Annotatef is used to add extra context to an existing error. The location of
// the Annotate call is recorded with the annotations. The file, line and
// function are also recorded.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Annotatef(err, "failed to frombulate the %s", arg)
// }
//
func Annotatef(other error, format string, args ...interface{}) error {
if other == nil {
return nil
}
err := &Err{
previous: other,
cause: Cause(other),
message: fmt.Sprintf(format, args...),
}
err.SetLocation(1)
return err
}
// DeferredAnnotatef annotates the given error (when it is not nil) with the given
// format string and arguments (like fmt.Sprintf). If *err is nil, DeferredAnnotatef
// does nothing. This method is used in a defer statement in order to annotate any
// resulting error with the same message.
//
// For example:
//
// defer DeferredAnnotatef(&err, "failed to frombulate the %s", arg)
//
func DeferredAnnotatef(err *error, format string, args ...interface{}) {
if *err == nil {
return
}
newErr := &Err{
message: fmt.Sprintf(format, args...),
cause: Cause(*err),
previous: *err,
}
newErr.SetLocation(1)
*err = newErr
}
// Wrap changes the Cause of the error. The location of the Wrap call is also
// stored in the error stack.
//
// For example:
// if err := SomeFunc(); err != nil {
// newErr := &packageError{"more context", private_value}
// return errors.Wrap(err, newErr)
// }
//
func Wrap(other, newDescriptive error) error {
err := &Err{
previous: other,
cause: newDescriptive,
}
err.SetLocation(1)
return err
}
// Wrapf changes the Cause of the error, and adds an annotation. The location
// of the Wrap call is also stored in the error stack.
//
// For example:
// if err := SomeFunc(); err != nil {
// return errors.Wrapf(err, simpleErrorType, "invalid value %q", value)
// }
//
func Wrapf(other, newDescriptive error, format string, args ...interface{}) error {
err := &Err{
message: fmt.Sprintf(format, args...),
previous: other,
cause: newDescriptive,
}
err.SetLocation(1)
return err
}
// Maskf masks the given error with the given format string and arguments (like
// fmt.Sprintf), returning a new error that maintains the error stack, but
// hides the underlying error type. The error string still contains the full
// annotations. If you want to hide the annotations, call Wrap.
func Maskf(other error, format string, args ...interface{}) error {
if other == nil {
return nil
}
err := &Err{
message: fmt.Sprintf(format, args...),
previous: other,
}
err.SetLocation(1)
return err
}
// Mask hides the underlying error type, and records the location of the masking.
func Mask(other error) error {
if other == nil {
return nil
}
err := &Err{
previous: other,
}
err.SetLocation(1)
return err
}
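// Example (editor's sketch): Mask keeps the annotation stack but hides the
// concrete type, so type predicates no longer match the cause.
//
//    err := errors.NotFoundf("thing")
//    masked := errors.Mask(err)
//    errors.IsNotFound(err)    // true
//    errors.IsNotFound(masked) // false; the cause is now the mask itself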
// Cause returns the cause of the given error. This will be either the
// original error, or the result of a Wrap or Mask call.
//
// Cause is the usual way to diagnose errors that may have been wrapped by
// the other errors functions.
func Cause(err error) error {
var diag error
if err, ok := err.(causer); ok {
diag = err.Cause()
}
if diag != nil {
return diag
}
return err
}
type causer interface {
Cause() error
}
type wrapper interface {
// Message returns the top level error message,
// not including the message from the Previous
// error.
Message() string
// Underlying returns the Previous error, or nil
// if there is none.
Underlying() error
}
type locationer interface {
Location() (string, int)
}
var (
_ wrapper = (*Err)(nil)
_ locationer = (*Err)(nil)
_ causer = (*Err)(nil)
)
// Details returns information about the stack of errors wrapped by err, in
// the format:
//
// [{filename:99: error one} {otherfile:55: cause of error one}]
//
// This is a terse alternative to ErrorStack as it returns a single line.
func Details(err error) string {
if err == nil {
return "[]"
}
var s []byte
s = append(s, '[')
for {
s = append(s, '{')
if err, ok := err.(locationer); ok {
file, line := err.Location()
if file != "" {
s = append(s, fmt.Sprintf("%s:%d", file, line)...)
s = append(s, ": "...)
}
}
if cerr, ok := err.(wrapper); ok {
s = append(s, cerr.Message()...)
err = cerr.Underlying()
} else {
s = append(s, err.Error()...)
err = nil
}
s = append(s, '}')
if err == nil {
break
}
s = append(s, ' ')
}
s = append(s, ']')
return string(s)
}
// ErrorStack returns a string representation of the annotated error. If the
// error passed as the parameter is not an annotated error, the result is
// simply the result of the Error() method on that error.
//
// If the error is an annotated error, a multi-line string is returned where
// each line represents one entry in the annotation stack. The full filename
// from the call stack is used in the output.
//
// first error
// github.com/juju/errors/annotation_test.go:193:
// github.com/juju/errors/annotation_test.go:194: annotation
// github.com/juju/errors/annotation_test.go:195:
// github.com/juju/errors/annotation_test.go:196: more context
// github.com/juju/errors/annotation_test.go:197:
func ErrorStack(err error) string {
return strings.Join(errorStack(err), "\n")
}
func errorStack(err error) []string {
if err == nil {
return nil
}
// We want the first error first
var lines []string
for {
var buff []byte
if err, ok := err.(locationer); ok {
file, line := err.Location()
// Strip off the leading GOPATH/src path elements.
file = trimGoPath(file)
if file != "" {
buff = append(buff, fmt.Sprintf("%s:%d", file, line)...)
buff = append(buff, ": "...)
}
}
if cerr, ok := err.(wrapper); ok {
message := cerr.Message()
buff = append(buff, message...)
// If there is a cause for this error, and it is different to the cause
// of the underlying error, then output the error string in the stack trace.
var cause error
if err1, ok := err.(causer); ok {
cause = err1.Cause()
}
err = cerr.Underlying()
if cause != nil && !sameError(Cause(err), cause) {
if message != "" {
buff = append(buff, ": "...)
}
buff = append(buff, cause.Error()...)
}
} else {
buff = append(buff, err.Error()...)
err = nil
}
lines = append(lines, string(buff))
if err == nil {
break
}
}
// reverse the lines to get the original error, which was at the end of
// the list, back to the start.
var result []string
for i := len(lines); i > 0; i-- {
result = append(result, lines[i-1])
}
return result
}

View File

@ -0,0 +1,38 @@
// Copyright 2013, 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package errors
import (
"runtime"
"strings"
)
// prefixSize is used internally to trim the user specific path from the
// front of the returned filenames from the runtime call stack.
var prefixSize int
// goPath is the deduced path based on the location of this file as compiled.
var goPath string
func init() {
_, file, _, ok := runtime.Caller(0)
if file == "?" {
return
}
if ok {
// We know that the end of the file should be:
// github.com/juju/errors/path.go
size := len(file)
suffix := len("github.com/juju/errors/path.go")
goPath = file[:size-suffix]
prefixSize = len(goPath)
}
}
func trimGoPath(filename string) string {
if strings.HasPrefix(filename, goPath) {
return filename[prefixSize:]
}
return filename
}

View File

@ -0,0 +1,187 @@
// Copyright (c) 2012 The Go Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package check
import (
"fmt"
"runtime"
"time"
)
var memStats runtime.MemStats
// timer holds the benchmark state embedded in C; it manages benchmark
// timing and the number of iterations to run.
type timer struct {
start time.Time // Time test or benchmark started
duration time.Duration
N int
bytes int64
timerOn bool
benchTime time.Duration
// The initial states of memStats.Mallocs and memStats.TotalAlloc.
startAllocs uint64
startBytes uint64
// The net total of this test after being run.
netAllocs uint64
netBytes uint64
}
// StartTimer starts timing a test. This function is called automatically
// before a benchmark starts, but it can also be used to resume timing after
// a call to StopTimer.
func (c *C) StartTimer() {
if !c.timerOn {
c.start = time.Now()
c.timerOn = true
runtime.ReadMemStats(&memStats)
c.startAllocs = memStats.Mallocs
c.startBytes = memStats.TotalAlloc
}
}
// StopTimer stops timing a test. This can be used to pause the timer
// while performing complex initialization that you don't
// want to measure.
func (c *C) StopTimer() {
if c.timerOn {
c.duration += time.Now().Sub(c.start)
c.timerOn = false
runtime.ReadMemStats(&memStats)
c.netAllocs += memStats.Mallocs - c.startAllocs
c.netBytes += memStats.TotalAlloc - c.startBytes
}
}
// ResetTimer sets the elapsed benchmark time to zero.
// It does not affect whether the timer is running.
func (c *C) ResetTimer() {
if c.timerOn {
c.start = time.Now()
runtime.ReadMemStats(&memStats)
c.startAllocs = memStats.Mallocs
c.startBytes = memStats.TotalAlloc
}
c.duration = 0
c.netAllocs = 0
c.netBytes = 0
}
// SetBytes records the number of bytes that the benchmark processes
// on each iteration. If this is called in a benchmark it will also
// report MB/s.
func (c *C) SetBytes(n int64) {
c.bytes = n
}
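// Example benchmark method (editor's sketch; the MySuite type and the
// makeInput/parse helpers are hypothetical): StopTimer and StartTimer exclude
// setup work from the measured time, and SetBytes enables the MB/s column.
//
//    func (s *MySuite) BenchmarkParse(c *C) {
//        c.StopTimer()
//        data := makeInput() // expensive setup, not measured
//        c.StartTimer()
//        c.SetBytes(int64(len(data)))
//        for i := 0; i < c.N; i++ {
//            parse(data)
//        }
//    }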
func (c *C) nsPerOp() int64 {
if c.N <= 0 {
return 0
}
return c.duration.Nanoseconds() / int64(c.N)
}
func (c *C) mbPerSec() float64 {
if c.bytes <= 0 || c.duration <= 0 || c.N <= 0 {
return 0
}
return (float64(c.bytes) * float64(c.N) / 1e6) / c.duration.Seconds()
}
func (c *C) timerString() string {
if c.N <= 0 {
return fmt.Sprintf("%3.3fs", float64(c.duration.Nanoseconds())/1e9)
}
mbs := c.mbPerSec()
mb := ""
if mbs != 0 {
mb = fmt.Sprintf("\t%7.2f MB/s", mbs)
}
nsop := c.nsPerOp()
ns := fmt.Sprintf("%10d ns/op", nsop)
if c.N > 0 && nsop < 100 {
// The format specifiers here make sure that
// the ones digits line up for all three possible formats.
if nsop < 10 {
ns = fmt.Sprintf("%13.2f ns/op", float64(c.duration.Nanoseconds())/float64(c.N))
} else {
ns = fmt.Sprintf("%12.1f ns/op", float64(c.duration.Nanoseconds())/float64(c.N))
}
}
memStats := ""
if c.benchMem {
allocedBytes := fmt.Sprintf("%8d B/op", int64(c.netBytes)/int64(c.N))
allocs := fmt.Sprintf("%8d allocs/op", int64(c.netAllocs)/int64(c.N))
memStats = fmt.Sprintf("\t%s\t%s", allocedBytes, allocs)
}
return fmt.Sprintf("%8d\t%s%s%s", c.N, ns, mb, memStats)
}
func min(x, y int) int {
if x > y {
return y
}
return x
}
func max(x, y int) int {
if x < y {
return y
}
return x
}
// roundDown10 rounds a number down to the nearest power of 10.
func roundDown10(n int) int {
var tens = 0
// tens = floor(log_10(n))
for n > 10 {
n = n / 10
tens++
}
// result = 10^tens
result := 1
for i := 0; i < tens; i++ {
result *= 10
}
return result
}
// roundUp rounds x up to a number of the form [1eX, 2eX, 5eX].
func roundUp(n int) int {
base := roundDown10(n)
if n < (2 * base) {
return 2 * base
}
if n < (5 * base) {
return 5 * base
}
return 10 * base
}
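// Worked example (editor's note): for n = 137, roundDown10 strips digits until
// n <= 10 (137 -> 13 -> 1, tens = 2), so the base is 100; since 137 < 2*100,
// roundUp returns 200. Benchmark iteration counts therefore grow through the
// sequence 1, 2, 5, 10, 20, 50, 100, ...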

View File

@ -0,0 +1,980 @@
// Package check is a rich testing extension for Go's testing package.
//
// For details about the project, see:
//
// http://labix.org/gocheck
//
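// A minimal suite wired into "go test" looks roughly like this (editor's
// sketch; Suite, TestingT, Assert and the Equals checker are exported by this
// package but not shown in this excerpt, and the user file imports "testing"):
//
//    func Test(t *testing.T) { check.TestingT(t) }
//
//    type MySuite struct{}
//
//    var _ = check.Suite(&MySuite{})
//
//    func (s *MySuite) TestGreeting(c *check.C) {
//        c.Assert("hello", check.Equals, "hello")
//    }
//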
package check
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// -----------------------------------------------------------------------
// Internal type which deals with suite method calling.
const (
fixtureKd = iota
testKd
)
type funcKind int
const (
succeededSt = iota
failedSt
skippedSt
panickedSt
fixturePanickedSt
missedSt
)
type funcStatus uint32
// A method value can't reach its own Method structure.
type methodType struct {
reflect.Value
Info reflect.Method
}
func newMethod(receiver reflect.Value, i int) *methodType {
return &methodType{receiver.Method(i), receiver.Type().Method(i)}
}
func (method *methodType) PC() uintptr {
return method.Info.Func.Pointer()
}
func (method *methodType) suiteName() string {
t := method.Info.Type.In(0)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t.Name()
}
func (method *methodType) String() string {
return method.suiteName() + "." + method.Info.Name
}
func (method *methodType) matches(re *regexp.Regexp) bool {
return (re.MatchString(method.Info.Name) ||
re.MatchString(method.suiteName()) ||
re.MatchString(method.String()))
}
type C struct {
method *methodType
kind funcKind
testName string
_status funcStatus
logb *logger
logw io.Writer
done chan *C
parallel chan *C
reason string
mustFail bool
tempDir *tempDir
benchMem bool
startTime time.Time
timer
}
func (c *C) status() funcStatus {
return funcStatus(atomic.LoadUint32((*uint32)(&c._status)))
}
func (c *C) setStatus(s funcStatus) {
atomic.StoreUint32((*uint32)(&c._status), uint32(s))
}
func (c *C) stopNow() {
runtime.Goexit()
}
// logger is a concurrency-safe bytes.Buffer
type logger struct {
sync.Mutex
writer bytes.Buffer
}
func (l *logger) Write(buf []byte) (int, error) {
l.Lock()
defer l.Unlock()
return l.writer.Write(buf)
}
func (l *logger) WriteTo(w io.Writer) (int64, error) {
l.Lock()
defer l.Unlock()
return l.writer.WriteTo(w)
}
func (l *logger) String() string {
l.Lock()
defer l.Unlock()
return l.writer.String()
}
// -----------------------------------------------------------------------
// Handling of temporary files and directories.
type tempDir struct {
sync.Mutex
path string
counter int
}
func (td *tempDir) newPath() string {
td.Lock()
defer td.Unlock()
if td.path == "" {
var err error
for i := 0; i != 100; i++ {
path := fmt.Sprintf("%s%ccheck-%d", os.TempDir(), os.PathSeparator, rand.Int())
if err = os.Mkdir(path, 0700); err == nil {
td.path = path
break
}
}
if td.path == "" {
panic("Couldn't create temporary directory: " + err.Error())
}
}
result := filepath.Join(td.path, strconv.Itoa(td.counter))
td.counter += 1
return result
}
func (td *tempDir) removeAll() {
td.Lock()
defer td.Unlock()
if td.path != "" {
err := os.RemoveAll(td.path)
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Error cleaning up temporaries: %s", err.Error())
}
}
}
// MkDir creates a new temporary directory which is automatically removed after
// the suite finishes running.
func (c *C) MkDir() string {
path := c.tempDir.newPath()
if err := os.Mkdir(path, 0700); err != nil {
panic(fmt.Sprintf("Couldn't create temporary directory %s: %s", path, err.Error()))
}
return path
}
// -----------------------------------------------------------------------
// Low-level logging functions.
func (c *C) log(args ...interface{}) {
c.writeLog([]byte(fmt.Sprint(args...) + "\n"))
}
func (c *C) logf(format string, args ...interface{}) {
c.writeLog([]byte(fmt.Sprintf(format+"\n", args...)))
}
func (c *C) logNewLine() {
c.writeLog([]byte{'\n'})
}
func (c *C) writeLog(buf []byte) {
c.logb.Write(buf)
if c.logw != nil {
c.logw.Write(buf)
}
}
func hasStringOrError(x interface{}) (ok bool) {
_, ok = x.(fmt.Stringer)
if ok {
return
}
_, ok = x.(error)
return
}
func (c *C) logValue(label string, value interface{}) {
if label == "" {
if hasStringOrError(value) {
c.logf("... %#v (%q)", value, value)
} else {
c.logf("... %#v", value)
}
} else if value == nil {
c.logf("... %s = nil", label)
} else {
if hasStringOrError(value) {
fv := fmt.Sprintf("%#v", value)
qv := fmt.Sprintf("%q", value)
if fv != qv {
c.logf("... %s %s = %s (%s)", label, reflect.TypeOf(value), fv, qv)
return
}
}
if s, ok := value.(string); ok && isMultiLine(s) {
c.logf(`... %s %s = "" +`, label, reflect.TypeOf(value))
c.logMultiLine(s)
} else {
c.logf("... %s %s = %#v", label, reflect.TypeOf(value), value)
}
}
}
func (c *C) logMultiLine(s string) {
b := make([]byte, 0, len(s)*2)
i := 0
n := len(s)
for i < n {
j := i + 1
for j < n && s[j-1] != '\n' {
j++
}
b = append(b, "... "...)
b = strconv.AppendQuote(b, s[i:j])
if j < n {
b = append(b, " +"...)
}
b = append(b, '\n')
i = j
}
c.writeLog(b)
}
func isMultiLine(s string) bool {
for i := 0; i+1 < len(s); i++ {
if s[i] == '\n' {
return true
}
}
return false
}
func (c *C) logString(issue string) {
c.log("... ", issue)
}
func (c *C) logCaller(skip int) {
// This is a bit heavier than it ought to be.
skip += 1 // Our own frame.
pc, callerFile, callerLine, ok := runtime.Caller(skip)
if !ok {
return
}
var testFile string
var testLine int
testFunc := runtime.FuncForPC(c.method.PC())
if runtime.FuncForPC(pc) != testFunc {
for {
skip += 1
if pc, file, line, ok := runtime.Caller(skip); ok {
// Note that the test line may be different on
// distinct calls for the same test. Showing
// the "internal" line is helpful when debugging.
if runtime.FuncForPC(pc) == testFunc {
testFile, testLine = file, line
break
}
} else {
break
}
}
}
if testFile != "" && (testFile != callerFile || testLine != callerLine) {
c.logCode(testFile, testLine)
}
c.logCode(callerFile, callerLine)
}
func (c *C) logCode(path string, line int) {
c.logf("%s:%d:", nicePath(path), line)
code, err := printLine(path, line)
if code == "" {
code = "..." // XXX Open the file and take the raw line.
if err != nil {
code += err.Error()
}
}
c.log(indent(code, " "))
}
var valueGo = filepath.Join("reflect", "value.go")
var asmGo = filepath.Join("runtime", "asm_")
func (c *C) logPanic(skip int, value interface{}) {
skip++ // Our own frame.
initialSkip := skip
for ; ; skip++ {
if pc, file, line, ok := runtime.Caller(skip); ok {
if skip == initialSkip {
c.logf("... Panic: %s (PC=0x%X)\n", value, pc)
}
name := niceFuncName(pc)
path := nicePath(file)
if strings.Contains(path, "/gopkg.in/check.v") {
continue
}
if name == "Value.call" && strings.HasSuffix(path, valueGo) {
continue
}
if (name == "call16" || name == "call32") && strings.Contains(path, asmGo) {
continue
}
c.logf("%s:%d\n in %s", nicePath(file), line, name)
} else {
break
}
}
}
func (c *C) logSoftPanic(issue string) {
c.log("... Panic: ", issue)
}
func (c *C) logArgPanic(method *methodType, expectedType string) {
c.logf("... Panic: %s argument should be %s",
niceFuncName(method.PC()), expectedType)
}
// -----------------------------------------------------------------------
// Some simple formatting helpers.
var initWD, initWDErr = os.Getwd()
func init() {
if initWDErr == nil {
initWD = strings.Replace(initWD, "\\", "/", -1) + "/"
}
}
func nicePath(path string) string {
if initWDErr == nil {
if strings.HasPrefix(path, initWD) {
return path[len(initWD):]
}
}
return path
}
func niceFuncPath(pc uintptr) string {
function := runtime.FuncForPC(pc)
if function != nil {
filename, line := function.FileLine(pc)
return fmt.Sprintf("%s:%d", nicePath(filename), line)
}
return "<unknown path>"
}
func niceFuncName(pc uintptr) string {
function := runtime.FuncForPC(pc)
if function != nil {
name := path.Base(function.Name())
if i := strings.Index(name, "."); i > 0 {
name = name[i+1:]
}
if strings.HasPrefix(name, "(*") {
if i := strings.Index(name, ")"); i > 0 {
name = name[2:i] + name[i+1:]
}
}
if i := strings.LastIndex(name, ".*"); i != -1 {
name = name[:i] + "." + name[i+2:]
}
if i := strings.LastIndex(name, "·"); i != -1 {
name = name[:i] + "." + name[i+2:]
}
return name
}
return "<unknown function>"
}
// -----------------------------------------------------------------------
// Result tracker to aggregate call results.
type Result struct {
Succeeded int
Failed int
Skipped int
Panicked int
FixturePanicked int
ExpectedFailures int
Missed int // Not even tried to run, related to a panic in the fixture.
RunError error // Houston, we've got a problem.
WorkDir string // If KeepWorkDir is true
}
type resultTracker struct {
result Result
_lastWasProblem bool
_waiting int
_missed int
_expectChan chan *C
_doneChan chan *C
_stopChan chan bool
}
func newResultTracker() *resultTracker {
return &resultTracker{_expectChan: make(chan *C), // Synchronous
_doneChan: make(chan *C, 32), // Asynchronous
_stopChan: make(chan bool)} // Synchronous
}
func (tracker *resultTracker) start() {
go tracker._loopRoutine()
}
func (tracker *resultTracker) waitAndStop() {
<-tracker._stopChan
}
func (tracker *resultTracker) expectCall(c *C) {
tracker._expectChan <- c
}
func (tracker *resultTracker) callDone(c *C) {
tracker._doneChan <- c
}
func (tracker *resultTracker) _loopRoutine() {
for {
var c *C
if tracker._waiting > 0 {
// Calls still running. Can't stop.
select {
// XXX Reindent this (not now to make diff clear)
case c = <-tracker._expectChan:
tracker._waiting += 1
case c = <-tracker._doneChan:
tracker._waiting -= 1
switch c.status() {
case succeededSt:
if c.kind == testKd {
if c.mustFail {
tracker.result.ExpectedFailures++
} else {
tracker.result.Succeeded++
}
}
case failedSt:
tracker.result.Failed++
case panickedSt:
if c.kind == fixtureKd {
tracker.result.FixturePanicked++
} else {
tracker.result.Panicked++
}
case fixturePanickedSt:
// Track it as missed, since the panic
// was on the fixture, not on the test.
tracker.result.Missed++
case missedSt:
tracker.result.Missed++
case skippedSt:
if c.kind == testKd {
tracker.result.Skipped++
}
}
}
} else {
// No calls. Can stop, but no done calls here.
select {
case tracker._stopChan <- true:
return
case c = <-tracker._expectChan:
tracker._waiting += 1
case c = <-tracker._doneChan:
panic("Tracker got an unexpected done call.")
}
}
}
}
// -----------------------------------------------------------------------
// The underlying suite runner.
type suiteRunner struct {
suite interface{}
setUpSuite, tearDownSuite *methodType
setUpTest, tearDownTest *methodType
tests []*methodType
tracker *resultTracker
tempDir *tempDir
keepDir bool
output *outputWriter
reportedProblemLast bool
benchTime time.Duration
benchMem bool
}
type RunConf struct {
Output io.Writer
Stream bool
Verbose bool
Filter string
Benchmark bool
BenchmarkTime time.Duration // Defaults to 1 second
BenchmarkMem bool
KeepWorkDir bool
Exclude string
}
// Create a new suiteRunner able to run all methods in the given suite.
func newSuiteRunner(suite interface{}, runConf *RunConf) *suiteRunner {
var conf RunConf
if runConf != nil {
conf = *runConf
}
if conf.Output == nil {
conf.Output = os.Stdout
}
if conf.Benchmark {
conf.Verbose = true
}
suiteType := reflect.TypeOf(suite)
suiteNumMethods := suiteType.NumMethod()
suiteValue := reflect.ValueOf(suite)
runner := &suiteRunner{
suite: suite,
output: newOutputWriter(conf.Output, conf.Stream, conf.Verbose),
tracker: newResultTracker(),
benchTime: conf.BenchmarkTime,
benchMem: conf.BenchmarkMem,
tempDir: &tempDir{},
keepDir: conf.KeepWorkDir,
tests: make([]*methodType, 0, suiteNumMethods),
}
if runner.benchTime == 0 {
runner.benchTime = 1 * time.Second
}
var filterRegexp *regexp.Regexp
if conf.Filter != "" {
if regexp, err := regexp.Compile(conf.Filter); err != nil {
msg := "Bad filter expression: " + err.Error()
runner.tracker.result.RunError = errors.New(msg)
return runner
} else {
filterRegexp = regexp
}
}
var excludeRegexp *regexp.Regexp
if conf.Exclude != "" {
if regexp, err := regexp.Compile(conf.Exclude); err != nil {
msg := "Bad exclude expression: " + err.Error()
runner.tracker.result.RunError = errors.New(msg)
return runner
} else {
excludeRegexp = regexp
}
}
for i := 0; i != suiteNumMethods; i++ {
method := newMethod(suiteValue, i)
switch method.Info.Name {
case "SetUpSuite":
runner.setUpSuite = method
case "TearDownSuite":
runner.tearDownSuite = method
case "SetUpTest":
runner.setUpTest = method
case "TearDownTest":
runner.tearDownTest = method
default:
prefix := "Test"
if conf.Benchmark {
prefix = "Benchmark"
}
if !strings.HasPrefix(method.Info.Name, prefix) {
continue
}
if filterRegexp == nil || method.matches(filterRegexp) {
if excludeRegexp == nil || !method.matches(excludeRegexp) {
runner.tests = append(runner.tests, method)
}
}
}
}
return runner
}
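// An illustrative sketch of the suite shape this discovery loop expects:
// fixture methods use the exact names matched above, and test methods use
// the Test prefix (MySuite and TestAddition are hypothetical names).
//
//     type MySuite struct{}
//
//     var _ = Suite(&MySuite{})
//
//     func (s *MySuite) SetUpSuite(c *C)    {} // once, before all tests
//     func (s *MySuite) SetUpTest(c *C)     {} // before each test
//     func (s *MySuite) TestAddition(c *C)  { c.Assert(1+1, Equals, 2) }
//     func (s *MySuite) TearDownTest(c *C)  {} // after each test
//     func (s *MySuite) TearDownSuite(c *C) {} // once, after all tests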
// Run all methods in the given suite.
func (runner *suiteRunner) run() *Result {
if runner.tracker.result.RunError == nil && len(runner.tests) > 0 {
runner.tracker.start()
if runner.checkFixtureArgs() {
c := runner.runFixture(runner.setUpSuite, "", nil)
if c == nil || c.status() == succeededSt {
var delayedC []*C
for i := 0; i != len(runner.tests); i++ {
c := runner.forkTest(runner.tests[i])
select {
case <-c.done:
case <-c.parallel:
delayedC = append(delayedC, c)
}
if c.status() == fixturePanickedSt {
runner.skipTests(missedSt, runner.tests[i+1:])
break
}
}
// Wait for the parallel tests to finish.
for _, delayed := range delayedC {
<-delayed.done
}
} else if c != nil && c.status() == skippedSt {
runner.skipTests(skippedSt, runner.tests)
} else {
runner.skipTests(missedSt, runner.tests)
}
runner.runFixture(runner.tearDownSuite, "", nil)
} else {
runner.skipTests(missedSt, runner.tests)
}
runner.tracker.waitAndStop()
if runner.keepDir {
runner.tracker.result.WorkDir = runner.tempDir.path
} else {
runner.tempDir.removeAll()
}
}
return &runner.tracker.result
}
// Create a call object with the given suite method, and fork a
// goroutine with the provided dispatcher for running it.
func (runner *suiteRunner) forkCall(method *methodType, kind funcKind, testName string, logb *logger, dispatcher func(c *C)) *C {
var logw io.Writer
if runner.output.Stream {
logw = runner.output
}
if logb == nil {
logb = new(logger)
}
c := &C{
method: method,
kind: kind,
testName: testName,
logb: logb,
logw: logw,
tempDir: runner.tempDir,
done: make(chan *C, 1),
parallel: make(chan *C, 1),
timer: timer{benchTime: runner.benchTime},
startTime: time.Now(),
benchMem: runner.benchMem,
}
runner.tracker.expectCall(c)
go (func() {
runner.reportCallStarted(c)
defer runner.callDone(c)
dispatcher(c)
})()
return c
}
// Same as forkCall(), but waits for the call to finish before returning.
func (runner *suiteRunner) runFunc(method *methodType, kind funcKind, testName string, logb *logger, dispatcher func(c *C)) *C {
c := runner.forkCall(method, kind, testName, logb, dispatcher)
<-c.done
return c
}
// Handle a finished call. If there were any panics, update the call status
// accordingly. Then, mark the call as done and report to the tracker.
func (runner *suiteRunner) callDone(c *C) {
value := recover()
if value != nil {
switch v := value.(type) {
case *fixturePanic:
if v.status == skippedSt {
c.setStatus(skippedSt)
} else {
c.logSoftPanic("Fixture has panicked (see related PANIC)")
c.setStatus(fixturePanickedSt)
}
default:
c.logPanic(1, value)
c.setStatus(panickedSt)
}
}
if c.mustFail {
switch c.status() {
case failedSt:
c.setStatus(succeededSt)
case succeededSt:
c.setStatus(failedSt)
c.logString("Error: Test succeeded, but was expected to fail")
c.logString("Reason: " + c.reason)
}
}
runner.reportCallDone(c)
c.done <- c
}
// Runs a fixture call synchronously. The fixture will still be run in a
// goroutine like all suite methods, but this method will not return until
// the fixture goroutine is done, because fixtures must run in a specific
// order.
func (runner *suiteRunner) runFixture(method *methodType, testName string, logb *logger) *C {
if method != nil {
c := runner.runFunc(method, fixtureKd, testName, logb, func(c *C) {
c.ResetTimer()
c.StartTimer()
defer c.StopTimer()
c.method.Call([]reflect.Value{reflect.ValueOf(c)})
})
return c
}
return nil
}
// Run the fixture method with runFixture(), but panic with a fixturePanic{}
// in case the fixture method panics. This makes it easier to track the
// fixture panic together with other call panics within forkTest().
func (runner *suiteRunner) runFixtureWithPanic(method *methodType, testName string, logb *logger, skipped *bool) *C {
if skipped != nil && *skipped {
return nil
}
c := runner.runFixture(method, testName, logb)
if c != nil && c.status() != succeededSt {
if skipped != nil {
*skipped = c.status() == skippedSt
}
panic(&fixturePanic{c.status(), method})
}
return c
}
type fixturePanic struct {
status funcStatus
method *methodType
}
// Run the suite test method, together with the test-specific fixture,
// asynchronously.
func (runner *suiteRunner) forkTest(method *methodType) *C {
testName := method.String()
return runner.forkCall(method, testKd, testName, nil, func(c *C) {
var skipped bool
defer runner.runFixtureWithPanic(runner.tearDownTest, testName, nil, &skipped)
defer c.StopTimer()
benchN := 1
for {
runner.runFixtureWithPanic(runner.setUpTest, testName, c.logb, &skipped)
mt := c.method.Type()
if mt.NumIn() != 1 || mt.In(0) != reflect.TypeOf(c) {
// Rather than a plain panic, provide a more helpful message when
// the argument type is incorrect.
c.setStatus(panickedSt)
c.logArgPanic(c.method, "*check.C")
return
}
if strings.HasPrefix(c.method.Info.Name, "Test") {
c.ResetTimer()
c.StartTimer()
c.method.Call([]reflect.Value{reflect.ValueOf(c)})
return
}
if !strings.HasPrefix(c.method.Info.Name, "Benchmark") {
panic("unexpected method prefix: " + c.method.Info.Name)
}
runtime.GC()
c.N = benchN
c.ResetTimer()
c.StartTimer()
c.method.Call([]reflect.Value{reflect.ValueOf(c)})
c.StopTimer()
if c.status() != succeededSt || c.duration >= c.benchTime || benchN >= 1e9 {
return
}
perOpN := int(1e9)
if c.nsPerOp() != 0 {
perOpN = int(c.benchTime.Nanoseconds() / c.nsPerOp())
}
// Logic taken from the stock testing package:
// - Run more iterations than we think we'll need for a second (1.5x).
// - Don't grow too fast in case we had timing errors previously.
// - Be sure to run at least one more than last time.
benchN = max(min(perOpN+perOpN/2, 100*benchN), benchN+1)
benchN = roundUp(benchN)
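			// A worked example of this growth, assuming the default 1s benchTime and a
			// first iteration (benchN == 1) that took roughly 1ms per op: perOpN is
			// 1e9/1e6 = 1000, min(1000+1000/2, 100*1) = 100, max(100, 1+1) = 100, so
			// the next round runs about 100 iterations; each round can then grow by at
			// most 100x until the accumulated duration reaches benchTime or benchN
			// hits 1e9.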
skipped = true // Don't run the deferred one if this panics.
runner.runFixtureWithPanic(runner.tearDownTest, testName, nil, nil)
skipped = false
}
})
}
// Same as forkTest(), but waits for the test to finish before returning.
func (runner *suiteRunner) runTest(method *methodType) *C {
c := runner.forkTest(method)
<-c.done
return c
}
// Helper to mark tests as skipped or missed. A bit heavy for what
// it does, but it enables homogeneous handling of tracking, including
// nice verbose output.
func (runner *suiteRunner) skipTests(status funcStatus, methods []*methodType) {
for _, method := range methods {
runner.runFunc(method, testKd, "", nil, func(c *C) {
c.setStatus(status)
})
}
}
// Verify that the fixture arguments are of type *check.C. In case of error,
// log the error as a panic in the fixture method call, and return false.
func (runner *suiteRunner) checkFixtureArgs() bool {
succeeded := true
argType := reflect.TypeOf(&C{})
for _, method := range []*methodType{runner.setUpSuite, runner.tearDownSuite, runner.setUpTest, runner.tearDownTest} {
if method != nil {
mt := method.Type()
if mt.NumIn() != 1 || mt.In(0) != argType {
succeeded = false
runner.runFunc(method, fixtureKd, "", nil, func(c *C) {
c.logArgPanic(method, "*check.C")
c.setStatus(panickedSt)
})
}
}
}
return succeeded
}
func (runner *suiteRunner) reportCallStarted(c *C) {
runner.output.WriteCallStarted("START", c)
}
func (runner *suiteRunner) reportCallDone(c *C) {
runner.tracker.callDone(c)
switch c.status() {
case succeededSt:
if c.mustFail {
runner.output.WriteCallSuccess("FAIL EXPECTED", c)
} else {
runner.output.WriteCallSuccess("PASS", c)
}
case skippedSt:
runner.output.WriteCallSuccess("SKIP", c)
case failedSt:
runner.output.WriteCallProblem("FAIL", c)
case panickedSt:
runner.output.WriteCallProblem("PANIC", c)
case fixturePanickedSt:
// That's a testKd call reporting that its fixture
// has panicked. The fixture call which caused the
// panic itself was tracked above. We'll report to
// aid debugging.
runner.output.WriteCallProblem("PANIC", c)
case missedSt:
runner.output.WriteCallSuccess("MISS", c)
}
}
// -----------------------------------------------------------------------
// Output writer manages atomic output writing according to settings.
type outputWriter struct {
m sync.Mutex
writer io.Writer
wroteCallProblemLast bool
Stream bool
Verbose bool
}
func newOutputWriter(writer io.Writer, stream, verbose bool) *outputWriter {
return &outputWriter{writer: writer, Stream: stream, Verbose: verbose}
}
func (ow *outputWriter) Write(content []byte) (n int, err error) {
ow.m.Lock()
n, err = ow.writer.Write(content)
ow.m.Unlock()
return
}
func (ow *outputWriter) WriteCallStarted(label string, c *C) {
if ow.Stream {
header := renderCallHeader(label, c, "", "\n")
ow.m.Lock()
ow.writer.Write([]byte(header))
ow.m.Unlock()
}
}
func (ow *outputWriter) WriteCallProblem(label string, c *C) {
var prefix string
if !ow.Stream {
prefix = "\n-----------------------------------" +
"-----------------------------------\n"
}
header := renderCallHeader(label, c, prefix, "\n\n")
ow.m.Lock()
ow.wroteCallProblemLast = true
ow.writer.Write([]byte(header))
if !ow.Stream {
c.logb.WriteTo(ow.writer)
}
ow.m.Unlock()
}
func (ow *outputWriter) WriteCallSuccess(label string, c *C) {
if ow.Stream || (ow.Verbose && c.kind == testKd) {
// TODO Use a buffer here.
var suffix string
if c.reason != "" {
suffix = " (" + c.reason + ")"
}
if c.status() == succeededSt {
suffix += "\t" + c.timerString()
}
suffix += "\n"
if ow.Stream {
suffix += "\n"
}
header := renderCallHeader(label, c, "", suffix)
ow.m.Lock()
// Resist the temptation to use line as a prefix above, due to a race.
if !ow.Stream && ow.wroteCallProblemLast {
header = "\n-----------------------------------" +
"-----------------------------------\n" +
header
}
ow.wroteCallProblemLast = false
ow.writer.Write([]byte(header))
ow.m.Unlock()
}
}
func renderCallHeader(label string, c *C, prefix, suffix string) string {
pc := c.method.PC()
return fmt.Sprintf("%s%s: %s: %s%s", prefix, label, niceFuncPath(pc),
niceFuncName(pc), suffix)
}

View File

@ -0,0 +1,458 @@
package check
import (
"fmt"
"reflect"
"regexp"
)
// -----------------------------------------------------------------------
// CommentInterface and Commentf helper, to attach extra information to checks.
type comment struct {
format string
args []interface{}
}
// Commentf returns an informational value to use with Assert or Check calls.
// If the checker test fails, the provided arguments will be passed to
// fmt.Sprintf, and will be presented next to the logged failure.
//
// For example:
//
// c.Assert(v, Equals, 42, Commentf("Iteration #%d failed.", i))
//
// Note that if the comment is constant, a better option is to
// simply use a normal comment right above or next to the line, as
// it will also get printed with any errors:
//
// c.Assert(l, Equals, 8192) // Ensure buffer size is correct (bug #123)
//
func Commentf(format string, args ...interface{}) CommentInterface {
return &comment{format, args}
}
// CommentInterface must be implemented by types that attach extra
// information to failed checks. See the Commentf function for details.
type CommentInterface interface {
CheckCommentString() string
}
func (c *comment) CheckCommentString() string {
return fmt.Sprintf(c.format, c.args...)
}
// -----------------------------------------------------------------------
// The Checker interface.
// The Checker interface must be provided by checkers used with
// the Assert and Check verification methods.
type Checker interface {
Info() *CheckerInfo
Check(params []interface{}, names []string) (result bool, error string)
}
// See the Checker interface.
type CheckerInfo struct {
Name string
Params []string
}
func (info *CheckerInfo) Info() *CheckerInfo {
return info
}
// -----------------------------------------------------------------------
// Not checker logic inverter.
// The Not checker inverts the logic of the provided checker. The
// resulting checker will succeed where the original one failed, and
// vice-versa.
//
// For example:
//
// c.Assert(a, Not(Equals), b)
//
func Not(checker Checker) Checker {
return &notChecker{checker}
}
type notChecker struct {
sub Checker
}
func (checker *notChecker) Info() *CheckerInfo {
info := *checker.sub.Info()
info.Name = "Not(" + info.Name + ")"
return &info
}
func (checker *notChecker) Check(params []interface{}, names []string) (result bool, error string) {
result, error = checker.sub.Check(params, names)
result = !result
return
}
// -----------------------------------------------------------------------
// IsNil checker.
type isNilChecker struct {
*CheckerInfo
}
// The IsNil checker tests whether the obtained value is nil.
//
// For example:
//
// c.Assert(err, IsNil)
//
var IsNil Checker = &isNilChecker{
&CheckerInfo{Name: "IsNil", Params: []string{"value"}},
}
func (checker *isNilChecker) Check(params []interface{}, names []string) (result bool, error string) {
return isNil(params[0]), ""
}
func isNil(obtained interface{}) (result bool) {
if obtained == nil {
result = true
} else {
switch v := reflect.ValueOf(obtained); v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return v.IsNil()
}
}
return
}
// -----------------------------------------------------------------------
// NotNil checker. Alias for Not(IsNil), since it's so common.
type notNilChecker struct {
*CheckerInfo
}
// The NotNil checker verifies that the obtained value is not nil.
//
// For example:
//
// c.Assert(iface, NotNil)
//
// This is an alias for Not(IsNil), made available since it's a
// fairly common check.
//
var NotNil Checker = &notNilChecker{
&CheckerInfo{Name: "NotNil", Params: []string{"value"}},
}
func (checker *notNilChecker) Check(params []interface{}, names []string) (result bool, error string) {
return !isNil(params[0]), ""
}
// -----------------------------------------------------------------------
// Equals checker.
type equalsChecker struct {
*CheckerInfo
}
// The Equals checker verifies that the obtained value is equal to
// the expected value, according to usual Go semantics for ==.
//
// For example:
//
// c.Assert(value, Equals, 42)
//
var Equals Checker = &equalsChecker{
&CheckerInfo{Name: "Equals", Params: []string{"obtained", "expected"}},
}
func (checker *equalsChecker) Check(params []interface{}, names []string) (result bool, error string) {
defer func() {
if v := recover(); v != nil {
result = false
error = fmt.Sprint(v)
}
}()
return params[0] == params[1], ""
}
// -----------------------------------------------------------------------
// DeepEquals checker.
type deepEqualsChecker struct {
*CheckerInfo
}
// The DeepEquals checker verifies that the obtained value is deep-equal to
// the expected value. The check will work correctly even when facing
// slices, interfaces, and values of different types (which always fail
// the test).
//
// For example:
//
// c.Assert(value, DeepEquals, 42)
// c.Assert(array, DeepEquals, []string{"hi", "there"})
//
var DeepEquals Checker = &deepEqualsChecker{
&CheckerInfo{Name: "DeepEquals", Params: []string{"obtained", "expected"}},
}
func (checker *deepEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
return reflect.DeepEqual(params[0], params[1]), ""
}
// -----------------------------------------------------------------------
// HasLen checker.
type hasLenChecker struct {
*CheckerInfo
}
// The HasLen checker verifies that the obtained value has the
// provided length. In many cases this is superior to using Equals
// in conjunction with the len function because, if the check fails,
// the value itself is printed instead of its length, providing more
// detail for figuring out the problem.
//
// For example:
//
// c.Assert(list, HasLen, 5)
//
var HasLen Checker = &hasLenChecker{
&CheckerInfo{Name: "HasLen", Params: []string{"obtained", "n"}},
}
func (checker *hasLenChecker) Check(params []interface{}, names []string) (result bool, error string) {
n, ok := params[1].(int)
if !ok {
return false, "n must be an int"
}
value := reflect.ValueOf(params[0])
switch value.Kind() {
case reflect.Map, reflect.Array, reflect.Slice, reflect.Chan, reflect.String:
default:
return false, "obtained value type has no length"
}
return value.Len() == n, ""
}
// -----------------------------------------------------------------------
// ErrorMatches checker.
type errorMatchesChecker struct {
*CheckerInfo
}
// The ErrorMatches checker verifies that the error value
// is non-nil and matches the regular expression provided.
//
// For example:
//
// c.Assert(err, ErrorMatches, "perm.*denied")
//
var ErrorMatches Checker = errorMatchesChecker{
&CheckerInfo{Name: "ErrorMatches", Params: []string{"value", "regex"}},
}
func (checker errorMatchesChecker) Check(params []interface{}, names []string) (result bool, errStr string) {
if params[0] == nil {
return false, "Error value is nil"
}
err, ok := params[0].(error)
if !ok {
return false, "Value is not an error"
}
params[0] = err.Error()
names[0] = "error"
return matches(params[0], params[1])
}
// -----------------------------------------------------------------------
// Matches checker.
type matchesChecker struct {
*CheckerInfo
}
// The Matches checker verifies that the string provided as the obtained
// value (or the string resulting from obtained.String()) matches the
// regular expression provided.
//
// For example:
//
// c.Assert(err, Matches, "perm.*denied")
//
var Matches Checker = &matchesChecker{
&CheckerInfo{Name: "Matches", Params: []string{"value", "regex"}},
}
func (checker *matchesChecker) Check(params []interface{}, names []string) (result bool, error string) {
return matches(params[0], params[1])
}
func matches(value, regex interface{}) (result bool, error string) {
reStr, ok := regex.(string)
if !ok {
return false, "Regex must be a string"
}
valueStr, valueIsStr := value.(string)
if !valueIsStr {
if valueWithStr, valueHasStr := value.(fmt.Stringer); valueHasStr {
valueStr, valueIsStr = valueWithStr.String(), true
}
}
if valueIsStr {
matches, err := regexp.MatchString("^"+reStr+"$", valueStr)
if err != nil {
return false, "Can't compile regex: " + err.Error()
}
return matches, ""
}
return false, "Obtained value is not a string and has no .String()"
}
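// An illustrative note on the helper above: the expression is anchored with
// "^" and "$", so the pattern must match the whole string. For a substring
// match, make the wildcards explicit, for example:
//
//     c.Assert(err, ErrorMatches, ".*permission denied.*")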
// -----------------------------------------------------------------------
// Panics checker.
type panicsChecker struct {
*CheckerInfo
}
// The Panics checker verifies that calling the provided zero-argument
// function will cause a panic which is deep-equal to the provided value.
//
// For example:
//
// c.Assert(func() { f(1, 2) }, Panics, &SomeErrorType{"BOOM"}).
//
//
var Panics Checker = &panicsChecker{
&CheckerInfo{Name: "Panics", Params: []string{"function", "expected"}},
}
func (checker *panicsChecker) Check(params []interface{}, names []string) (result bool, error string) {
f := reflect.ValueOf(params[0])
if f.Kind() != reflect.Func || f.Type().NumIn() != 0 {
return false, "Function must take zero arguments"
}
defer func() {
// If the function has not panicked, then don't do the check.
if error != "" {
return
}
params[0] = recover()
names[0] = "panic"
result = reflect.DeepEqual(params[0], params[1])
}()
f.Call(nil)
return false, "Function has not panicked"
}
type panicMatchesChecker struct {
*CheckerInfo
}
// The PanicMatches checker verifies that calling the provided zero-argument
// function will cause a panic with an error value matching
// the regular expression provided.
//
// For example:
//
// c.Assert(func() { f(1, 2) }, PanicMatches, `open.*: no such file or directory`).
//
//
var PanicMatches Checker = &panicMatchesChecker{
&CheckerInfo{Name: "PanicMatches", Params: []string{"function", "expected"}},
}
func (checker *panicMatchesChecker) Check(params []interface{}, names []string) (result bool, errmsg string) {
f := reflect.ValueOf(params[0])
if f.Kind() != reflect.Func || f.Type().NumIn() != 0 {
return false, "Function must take zero arguments"
}
defer func() {
// If the function has not panicked, then don't do the check.
if errmsg != "" {
return
}
obtained := recover()
names[0] = "panic"
if e, ok := obtained.(error); ok {
params[0] = e.Error()
} else if _, ok := obtained.(string); ok {
params[0] = obtained
} else {
errmsg = "Panic value is not a string or an error"
return
}
result, errmsg = matches(params[0], params[1])
}()
f.Call(nil)
return false, "Function has not panicked"
}
// -----------------------------------------------------------------------
// FitsTypeOf checker.
type fitsTypeChecker struct {
*CheckerInfo
}
// The FitsTypeOf checker verifies that the obtained value is
// assignable to a variable with the same type as the provided
// sample value.
//
// For example:
//
// c.Assert(value, FitsTypeOf, int64(0))
// c.Assert(value, FitsTypeOf, os.Error(nil))
//
var FitsTypeOf Checker = &fitsTypeChecker{
&CheckerInfo{Name: "FitsTypeOf", Params: []string{"obtained", "sample"}},
}
func (checker *fitsTypeChecker) Check(params []interface{}, names []string) (result bool, error string) {
obtained := reflect.ValueOf(params[0])
sample := reflect.ValueOf(params[1])
if !obtained.IsValid() {
return false, ""
}
if !sample.IsValid() {
return false, "Invalid sample value"
}
return obtained.Type().AssignableTo(sample.Type()), ""
}
// -----------------------------------------------------------------------
// Implements checker.
type implementsChecker struct {
*CheckerInfo
}
// The Implements checker verifies that the obtained value
// implements the interface specified via a pointer to an interface
// variable.
//
// For example:
//
// var e os.Error
// c.Assert(err, Implements, &e)
//
var Implements Checker = &implementsChecker{
&CheckerInfo{Name: "Implements", Params: []string{"obtained", "ifaceptr"}},
}
func (checker *implementsChecker) Check(params []interface{}, names []string) (result bool, error string) {
obtained := reflect.ValueOf(params[0])
ifaceptr := reflect.ValueOf(params[1])
if !obtained.IsValid() {
return false, ""
}
if !ifaceptr.IsValid() || ifaceptr.Kind() != reflect.Ptr || ifaceptr.Elem().Kind() != reflect.Interface {
return false, "ifaceptr should be a pointer to an interface variable"
}
return obtained.Type().Implements(ifaceptr.Elem().Type()), ""
}

View File

@ -0,0 +1,131 @@
// Extensions to the go-check unittest framework.
//
// NOTE: see https://github.com/go-check/check/pull/6 for reasons why these
// checkers live here.
package check
import (
"bytes"
"reflect"
)
// -----------------------------------------------------------------------
// IsTrue / IsFalse checker.
type isBoolValueChecker struct {
*CheckerInfo
expected bool
}
func (checker *isBoolValueChecker) Check(
params []interface{},
names []string) (
result bool,
error string) {
obtained, ok := params[0].(bool)
if !ok {
return false, "Argument to " + checker.Name + " must be bool"
}
return obtained == checker.expected, ""
}
// The IsTrue checker verifies that the obtained value is true.
//
// For example:
//
// c.Assert(value, IsTrue)
//
var IsTrue Checker = &isBoolValueChecker{
&CheckerInfo{Name: "IsTrue", Params: []string{"obtained"}},
true,
}
// The IsFalse checker verifies that the obtained value is false.
//
// For example:
//
// c.Assert(value, IsFalse)
//
var IsFalse Checker = &isBoolValueChecker{
&CheckerInfo{Name: "IsFalse", Params: []string{"obtained"}},
false,
}
// -----------------------------------------------------------------------
// BytesEquals checker.
type bytesEquals struct{}
func (b *bytesEquals) Check(params []interface{}, names []string) (bool, string) {
if len(params) != 2 {
return false, "BytesEqual takes 2 bytestring arguments"
}
b1, ok1 := params[0].([]byte)
b2, ok2 := params[1].([]byte)
if !(ok1 && ok2) {
return false, "Arguments to BytesEqual must both be bytestrings"
}
return bytes.Equal(b1, b2), ""
}
func (b *bytesEquals) Info() *CheckerInfo {
return &CheckerInfo{
Name: "BytesEquals",
Params: []string{"bytes_one", "bytes_two"},
}
}
// The BytesEquals checker compares two byte sequences using bytes.Equal.
//
// For example:
//
// c.Assert(b, BytesEquals, []byte("bar"))
//
// The main difference between DeepEquals and BytesEquals is that BytesEquals
// treats `nil` as an empty byte sequence, while DeepEquals does not.
//
// c.Assert(nil, BytesEquals, []byte("")) // succeeds
// c.Assert(nil, DeepEquals, []byte("")) // fails
var BytesEquals = &bytesEquals{}
// -----------------------------------------------------------------------
// HasKey checker.
type hasKey struct{}
func (h *hasKey) Check(params []interface{}, names []string) (bool, string) {
if len(params) != 2 {
return false, "HasKey takes 2 arguments: a map and a key"
}
mapValue := reflect.ValueOf(params[0])
if mapValue.Kind() != reflect.Map {
return false, "First argument to HasKey must be a map"
}
keyValue := reflect.ValueOf(params[1])
if !keyValue.Type().AssignableTo(mapValue.Type().Key()) {
return false, "Second argument must be assignable to the map key type"
}
return mapValue.MapIndex(keyValue).IsValid(), ""
}
func (h *hasKey) Info() *CheckerInfo {
return &CheckerInfo{
Name: "HasKey",
Params: []string{"obtained", "key"},
}
}
// The HasKey checker verifies that the obtained map contains the given key.
//
// For example:
//
// c.Assert(myMap, HasKey, "foo")
//
var HasKey = &hasKey{}

View File

@ -0,0 +1,161 @@
package check
import (
"bytes"
"fmt"
"reflect"
"time"
)
type compareFunc func(v1 interface{}, v2 interface{}) (bool, error)
type valueCompare struct {
Name string
Func compareFunc
Operator string
}
// compare requires v1 and v2 to have the same type.
// It returns >0 if v1 > v2, 0 if v1 == v2, and <0 if v1 < v2.
// Supported types: the int and uint variants, float32/float64, string,
// []byte, time.Time, and time.Duration.
func compare(v1 interface{}, v2 interface{}) (int, error) {
value1 := reflect.ValueOf(v1)
value2 := reflect.ValueOf(v2)
switch v1.(type) {
case int, int8, int16, int32, int64:
a1 := value1.Int()
a2 := value2.Int()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case uint, uint8, uint16, uint32, uint64:
a1 := value1.Uint()
a2 := value2.Uint()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case float32, float64:
a1 := value1.Float()
a2 := value2.Float()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case string:
a1 := value1.String()
a2 := value2.String()
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
case []byte:
a1 := value1.Bytes()
a2 := value2.Bytes()
return bytes.Compare(a1, a2), nil
case time.Time:
a1 := v1.(time.Time)
a2 := v2.(time.Time)
if a1.After(a2) {
return 1, nil
} else if a1.Equal(a2) {
return 0, nil
}
return -1, nil
case time.Duration:
a1 := v1.(time.Duration)
a2 := v2.(time.Duration)
if a1 > a2 {
return 1, nil
} else if a1 == a2 {
return 0, nil
}
return -1, nil
default:
return 0, fmt.Errorf("type %T is not supported", v1)
}
}
func less(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n < 0, nil
}
func lessEqual(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n <= 0, nil
}
func greater(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n > 0, nil
}
func greaterEqual(v1 interface{}, v2 interface{}) (bool, error) {
n, err := compare(v1, v2)
if err != nil {
return false, err
}
return n >= 0, nil
}
func (v *valueCompare) Check(params []interface{}, names []string) (bool, string) {
if len(params) != 2 {
return false, fmt.Sprintf("%s needs 2 arguments", v.Name)
}
v1 := params[0]
v2 := params[1]
v1Type := reflect.TypeOf(v1)
v2Type := reflect.TypeOf(v2)
if v1Type.Kind() != v2Type.Kind() {
return false, fmt.Sprintf("%s needs two same type, but %s != %s", v.Name, v1Type.Kind(), v2Type.Kind())
}
b, err := v.Func(v1, v2)
if err != nil {
return false, fmt.Sprintf("%s check err %v", v.Name, err)
}
return b, ""
}
func (v *valueCompare) Info() *CheckerInfo {
return &CheckerInfo{
Name: v.Name,
Params: []string{"compare_one", "compare_two"},
}
}
var Less = &valueCompare{Name: "Less", Func: less, Operator: "<"}
var LessEqual = &valueCompare{Name: "LessEqual", Func: lessEqual, Operator: "<="}
var Greater = &valueCompare{Name: "Greater", Func: greater, Operator: ">"}
var GreaterEqual = &valueCompare{Name: "GreaterEqual", Func: greaterEqual, Operator: ">="}
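// Illustrative usage of the comparison checkers above; both arguments must
// share the same reflect.Kind, as enforced by valueCompare.Check:
//
//     c.Assert(1, Less, 2)
//     c.Assert("abc", GreaterEqual, "abb")
//     c.Assert(time.Second, LessEqual, 2*time.Second)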

View File

@ -0,0 +1,236 @@
package check
import (
"fmt"
"strings"
"time"
)
// TestName returns the current test name in the form "SuiteName.TestName"
func (c *C) TestName() string {
return c.testName
}
// -----------------------------------------------------------------------
// Basic succeeding/failing logic.
// Failed returns whether the currently running test has already failed.
func (c *C) Failed() bool {
return c.status() == failedSt
}
// Fail marks the currently running test as failed.
//
// Something ought to have been previously logged so the developer can tell
// what went wrong. The higher level helper functions will fail the test
// and do the logging properly.
func (c *C) Fail() {
c.setStatus(failedSt)
}
// FailNow marks the currently running test as failed and stops running it.
// Something ought to have been previously logged so the developer can tell
// what went wrong. The higher level helper functions will fail the test
// and do the logging properly.
func (c *C) FailNow() {
c.Fail()
c.stopNow()
}
// Succeed marks the currently running test as succeeded, undoing any
// previous failures.
func (c *C) Succeed() {
c.setStatus(succeededSt)
}
// SucceedNow marks the currently running test as succeeded, undoing any
// previous failures, and stops running the test.
func (c *C) SucceedNow() {
c.Succeed()
c.stopNow()
}
// ExpectFailure informs that the running test is knowingly broken for
// the provided reason. If the test does not fail, an error will be reported
// to raise attention to this fact. This method is useful to temporarily
// disable tests which cover well known problems until a better time to
// fix the problem is found, without forgetting about the fact that a
// failure still exists.
func (c *C) ExpectFailure(reason string) {
if reason == "" {
panic("Missing reason why the test is expected to fail")
}
c.mustFail = true
c.reason = reason
}
// Skip skips the running test for the provided reason. If run from within
// SetUpTest, the individual test being set up will be skipped, and if run
// from within SetUpSuite, the whole suite is skipped.
func (c *C) Skip(reason string) {
if reason == "" {
panic("Missing reason why the test is being skipped")
}
c.reason = reason
c.setStatus(skippedSt)
c.stopNow()
}
// Parallel will mark the test run parallel within a test suite.
func (c *C) Parallel() {
c.parallel <- c
}
// -----------------------------------------------------------------------
// Basic logging.
// GetTestLog returns the current test error output.
func (c *C) GetTestLog() string {
return c.logb.String()
}
// Log logs some information into the test error output.
// The provided arguments are assembled together into a string with fmt.Sprint.
func (c *C) Log(args ...interface{}) {
c.log(args...)
}
// Logf logs some information into the test error output.
// The provided arguments are assembled together into a string with fmt.Sprintf.
func (c *C) Logf(format string, args ...interface{}) {
c.logf(format, args...)
}
// Output enables *C to be used as a logger in functions that require only
// the minimum interface of *log.Logger.
func (c *C) Output(calldepth int, s string) error {
d := time.Now().Sub(c.startTime)
msec := d / time.Millisecond
sec := d / time.Second
min := d / time.Minute
c.Logf("[LOG] %d:%02d.%03d %s", min, sec%60, msec%1000, s)
return nil
}
// Error logs an error into the test error output and marks the test as failed.
// The provided arguments are assembled together into a string with fmt.Sprint.
func (c *C) Error(args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprint("Error: ", fmt.Sprint(args...)))
c.logNewLine()
c.Fail()
}
// Errorf logs an error into the test error output and marks the test as failed.
// The provided arguments are assembled together into a string with fmt.Sprintf.
func (c *C) Errorf(format string, args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprintf("Error: "+format, args...))
c.logNewLine()
c.Fail()
}
// Fatal logs an error into the test error output, marks the test as failed, and
// stops the test execution. The provided arguments are assembled together into
// a string with fmt.Sprint.
func (c *C) Fatal(args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprint("Error: ", fmt.Sprint(args...)))
c.logNewLine()
c.FailNow()
}
// Fatalf logs an error into the test error output, marks the test as failed, and
// stops the test execution. The provided arguments are assembled together into
// a string with fmt.Sprintf.
func (c *C) Fatalf(format string, args ...interface{}) {
c.logCaller(1)
c.logString(fmt.Sprint("Error: ", fmt.Sprintf(format, args...)))
c.logNewLine()
c.FailNow()
}
// -----------------------------------------------------------------------
// Generic checks and assertions based on checkers.
// Check verifies if the first value matches the expected value according
// to the provided checker. If they do not match, an error is logged, the
// test is marked as failed, and the test execution continues.
//
// Some checkers may not need the expected argument (e.g. IsNil).
//
// Extra arguments provided to the function are logged next to the reported
// problem when the matching fails.
func (c *C) Check(obtained interface{}, checker Checker, args ...interface{}) bool {
return c.internalCheck("Check", obtained, checker, args...)
}
// Assert ensures that the first value matches the expected value according
// to the provided checker. If they do not match, an error is logged, the
// test is marked as failed, and the test execution stops.
//
// Some checkers may not need the expected argument (e.g. IsNil).
//
// Extra arguments provided to the function are logged next to the reported
// problem when the matching fails.
func (c *C) Assert(obtained interface{}, checker Checker, args ...interface{}) {
if !c.internalCheck("Assert", obtained, checker, args...) {
c.stopNow()
}
}
func (c *C) internalCheck(funcName string, obtained interface{}, checker Checker, args ...interface{}) bool {
if checker == nil {
c.logCaller(2)
c.logString(fmt.Sprintf("%s(obtained, nil!?, ...):", funcName))
c.logString("Oops.. you've provided a nil checker!")
c.logNewLine()
c.Fail()
return false
}
// If the last argument is a comment (CommentInterface), extract it out.
var comment CommentInterface
if len(args) > 0 {
if c, ok := args[len(args)-1].(CommentInterface); ok {
comment = c
args = args[:len(args)-1]
}
}
params := append([]interface{}{obtained}, args...)
info := checker.Info()
if len(params) != len(info.Params) {
names := append([]string{info.Params[0], info.Name}, info.Params[1:]...)
c.logCaller(2)
c.logString(fmt.Sprintf("%s(%s):", funcName, strings.Join(names, ", ")))
c.logString(fmt.Sprintf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(names), len(params)+1))
c.logNewLine()
c.Fail()
return false
}
// Copy since it may be mutated by Check.
names := append([]string{}, info.Params...)
// Do the actual check.
result, error := checker.Check(params, names)
if !result || error != "" {
c.logCaller(2)
for i := 0; i != len(params); i++ {
c.logValue(names[i], params[i])
}
if comment != nil {
c.logString(comment.CheckCommentString())
}
if error != "" {
c.logString(error)
}
c.logNewLine()
c.Fail()
return false
}
return true
}

View File

@ -0,0 +1,168 @@
package check
import (
"bytes"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
)
func indent(s, with string) (r string) {
eol := true
for i := 0; i != len(s); i++ {
c := s[i]
switch {
case eol && c == '\n' || c == '\r':
case c == '\n' || c == '\r':
eol = true
case eol:
eol = false
s = s[:i] + with + s[i:]
i += len(with)
}
}
return s
}
func printLine(filename string, line int) (string, error) {
fset := token.NewFileSet()
file, err := os.Open(filename)
if err != nil {
return "", err
}
fnode, err := parser.ParseFile(fset, filename, file, parser.ParseComments)
if err != nil {
return "", err
}
config := &printer.Config{Mode: printer.UseSpaces, Tabwidth: 4}
lp := &linePrinter{fset: fset, fnode: fnode, line: line, config: config}
ast.Walk(lp, fnode)
result := lp.output.Bytes()
// Comments leave \n at the end.
n := len(result)
for n > 0 && result[n-1] == '\n' {
n--
}
return string(result[:n]), nil
}
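// An illustrative sketch of how printLine is meant to be used: for a failing
// line such as
//
//     c.Assert(value, Equals, 42)
//
// printLine(filename, line) re-renders just that statement (plus any comment
// attached to it), which is what the caller-logging helpers can embed in the
// test log.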
type linePrinter struct {
config *printer.Config
fset *token.FileSet
fnode *ast.File
line int
output bytes.Buffer
stmt ast.Stmt
}
func (lp *linePrinter) emit() bool {
if lp.stmt != nil {
lp.trim(lp.stmt)
lp.printWithComments(lp.stmt)
lp.stmt = nil
return true
}
return false
}
func (lp *linePrinter) printWithComments(n ast.Node) {
nfirst := lp.fset.Position(n.Pos()).Line
nlast := lp.fset.Position(n.End()).Line
for _, g := range lp.fnode.Comments {
cfirst := lp.fset.Position(g.Pos()).Line
clast := lp.fset.Position(g.End()).Line
if clast == nfirst-1 && lp.fset.Position(n.Pos()).Column == lp.fset.Position(g.Pos()).Column {
for _, c := range g.List {
lp.output.WriteString(c.Text)
lp.output.WriteByte('\n')
}
}
if cfirst >= nfirst && cfirst <= nlast && n.End() <= g.List[0].Slash {
// The printer will not include the comment if it starts past
// the node itself. Trick it into printing by overlapping the
// slash with the end of the statement.
g.List[0].Slash = n.End() - 1
}
}
node := &printer.CommentedNode{n, lp.fnode.Comments}
lp.config.Fprint(&lp.output, lp.fset, node)
}
func (lp *linePrinter) Visit(n ast.Node) (w ast.Visitor) {
if n == nil {
if lp.output.Len() == 0 {
lp.emit()
}
return nil
}
first := lp.fset.Position(n.Pos()).Line
last := lp.fset.Position(n.End()).Line
if first <= lp.line && last >= lp.line {
// Print the innermost statement containing the line.
if stmt, ok := n.(ast.Stmt); ok {
if _, ok := n.(*ast.BlockStmt); !ok {
lp.stmt = stmt
}
}
if first == lp.line && lp.emit() {
return nil
}
return lp
}
return nil
}
func (lp *linePrinter) trim(n ast.Node) bool {
stmt, ok := n.(ast.Stmt)
if !ok {
return true
}
line := lp.fset.Position(n.Pos()).Line
if line != lp.line {
return false
}
switch stmt := stmt.(type) {
case *ast.IfStmt:
stmt.Body = lp.trimBlock(stmt.Body)
case *ast.SwitchStmt:
stmt.Body = lp.trimBlock(stmt.Body)
case *ast.TypeSwitchStmt:
stmt.Body = lp.trimBlock(stmt.Body)
case *ast.CaseClause:
stmt.Body = lp.trimList(stmt.Body)
case *ast.CommClause:
stmt.Body = lp.trimList(stmt.Body)
case *ast.BlockStmt:
stmt.List = lp.trimList(stmt.List)
}
return true
}
func (lp *linePrinter) trimBlock(stmt *ast.BlockStmt) *ast.BlockStmt {
if !lp.trim(stmt) {
return lp.emptyBlock(stmt)
}
stmt.Rbrace = stmt.Lbrace
return stmt
}
func (lp *linePrinter) trimList(stmts []ast.Stmt) []ast.Stmt {
for i := 0; i != len(stmts); i++ {
if !lp.trim(stmts[i]) {
stmts[i] = lp.emptyStmt(stmts[i])
break
}
}
return stmts
}
func (lp *linePrinter) emptyStmt(n ast.Node) *ast.ExprStmt {
return &ast.ExprStmt{&ast.Ellipsis{n.Pos(), nil}}
}
func (lp *linePrinter) emptyBlock(n ast.Node) *ast.BlockStmt {
p := n.Pos()
return &ast.BlockStmt{p, []ast.Stmt{lp.emptyStmt(n)}, p}
}

View File

@ -0,0 +1,179 @@
package check
import (
"bufio"
"flag"
"fmt"
"os"
"testing"
"time"
)
// -----------------------------------------------------------------------
// Test suite registry.
var allSuites []interface{}
// Suite registers the given value as a test suite to be run. Any methods
// starting with the Test prefix in the given value will be considered
// test methods.
func Suite(suite interface{}) interface{} {
allSuites = append(allSuites, suite)
return suite
}
// -----------------------------------------------------------------------
// Public running interface.
var (
oldFilterFlag = flag.String("gocheck.f", "", "Regular expression selecting which tests and/or suites to run")
oldVerboseFlag = flag.Bool("gocheck.v", false, "Verbose mode")
oldStreamFlag = flag.Bool("gocheck.vv", false, "Super verbose mode (disables output caching)")
oldBenchFlag = flag.Bool("gocheck.b", false, "Run benchmarks")
oldBenchTime = flag.Duration("gocheck.btime", 1*time.Second, "approximate run time for each benchmark")
oldListFlag = flag.Bool("gocheck.list", false, "List the names of all tests that will be run")
oldWorkFlag = flag.Bool("gocheck.work", false, "Display and do not remove the test working directory")
newFilterFlag = flag.String("check.f", "", "Regular expression selecting which tests and/or suites to run")
newVerboseFlag = flag.Bool("check.v", false, "Verbose mode")
newStreamFlag = flag.Bool("check.vv", false, "Super verbose mode (disables output caching)")
newBenchFlag = flag.Bool("check.b", false, "Run benchmarks")
newBenchTime = flag.Duration("check.btime", 1*time.Second, "approximate run time for each benchmark")
newBenchMem = flag.Bool("check.bmem", false, "Report memory benchmarks")
newListFlag = flag.Bool("check.list", false, "List the names of all tests that will be run")
newWorkFlag = flag.Bool("check.work", false, "Display and do not remove the test working directory")
newExcludeFlag = flag.String("check.exclude", "", "Regular expression to exclude tests to run")
)
var CustomVerboseFlag bool
// TestingT runs all test suites registered with the Suite function,
// printing results to stdout, and reporting any failures back to
// the "testing" package.
func TestingT(testingT *testing.T) {
benchTime := *newBenchTime
if benchTime == 1*time.Second {
benchTime = *oldBenchTime
}
conf := &RunConf{
Filter: *oldFilterFlag + *newFilterFlag,
Verbose: *oldVerboseFlag || *newVerboseFlag || CustomVerboseFlag,
Stream: *oldStreamFlag || *newStreamFlag,
Benchmark: *oldBenchFlag || *newBenchFlag,
BenchmarkTime: benchTime,
BenchmarkMem: *newBenchMem,
KeepWorkDir: *oldWorkFlag || *newWorkFlag,
Exclude: *newExcludeFlag,
}
if *oldListFlag || *newListFlag {
w := bufio.NewWriter(os.Stdout)
for _, name := range ListAll(conf) {
fmt.Fprintln(w, name)
}
w.Flush()
return
}
result := RunAll(conf)
println(result.String())
if !result.Passed() {
testingT.Fail()
}
}
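// Illustrative wiring for a package using this runner, following the usual
// convention of a single standard Go test that hands control to TestingT
// (MySuite is a hypothetical suite type):
//
//     func Test(t *testing.T) { TestingT(t) }
//
//     type MySuite struct{}
//
//     var _ = Suite(&MySuite{})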
// RunAll runs all test suites registered with the Suite function, using the
// provided run configuration.
func RunAll(runConf *RunConf) *Result {
result := Result{}
for _, suite := range allSuites {
result.Add(Run(suite, runConf))
}
return &result
}
// Run runs the provided test suite using the provided run configuration.
func Run(suite interface{}, runConf *RunConf) *Result {
runner := newSuiteRunner(suite, runConf)
return runner.run()
}
// ListAll returns the names of all the test functions registered with the
// Suite function that will be run with the provided run configuration.
func ListAll(runConf *RunConf) []string {
var names []string
for _, suite := range allSuites {
names = append(names, List(suite, runConf)...)
}
return names
}
// List returns the names of the test functions in the given
// suite that will be run with the provided run configuration.
func List(suite interface{}, runConf *RunConf) []string {
var names []string
runner := newSuiteRunner(suite, runConf)
for _, t := range runner.tests {
names = append(names, t.String())
}
return names
}
// -----------------------------------------------------------------------
// Result methods.
func (r *Result) Add(other *Result) {
r.Succeeded += other.Succeeded
r.Skipped += other.Skipped
r.Failed += other.Failed
r.Panicked += other.Panicked
r.FixturePanicked += other.FixturePanicked
r.ExpectedFailures += other.ExpectedFailures
r.Missed += other.Missed
if r.WorkDir != "" && other.WorkDir != "" {
r.WorkDir += ":" + other.WorkDir
} else if other.WorkDir != "" {
r.WorkDir = other.WorkDir
}
}
func (r *Result) Passed() bool {
return (r.Failed == 0 && r.Panicked == 0 &&
r.FixturePanicked == 0 && r.Missed == 0 &&
r.RunError == nil)
}
func (r *Result) String() string {
if r.RunError != nil {
return "ERROR: " + r.RunError.Error()
}
var value string
if r.Failed == 0 && r.Panicked == 0 && r.FixturePanicked == 0 &&
r.Missed == 0 {
value = "OK: "
} else {
value = "OOPS: "
}
value += fmt.Sprintf("%d passed", r.Succeeded)
if r.Skipped != 0 {
value += fmt.Sprintf(", %d skipped", r.Skipped)
}
if r.ExpectedFailures != 0 {
value += fmt.Sprintf(", %d expected failures", r.ExpectedFailures)
}
if r.Failed != 0 {
value += fmt.Sprintf(", %d FAILED", r.Failed)
}
if r.Panicked != 0 {
value += fmt.Sprintf(", %d PANICKED", r.Panicked)
}
if r.FixturePanicked != 0 {
value += fmt.Sprintf(", %d FIXTURE-PANICKED", r.FixturePanicked)
}
if r.Missed != 0 {
value += fmt.Sprintf(", %d MISSED", r.Missed)
}
if r.WorkDir != "" {
value += "\nWORK=" + r.WorkDir
}
return value
}
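// Illustrative output of String above, assuming one failed and two skipped
// tests in an otherwise clean run:
//
//     OOPS: 12 passed, 2 skipped, 1 FAILED
//
// and for a fully successful run:
//
//     OK: 15 passed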

View File

@ -0,0 +1,20 @@
Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,206 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"bytes"
"encoding/hex"
"fmt"
)
// FromBytes returns UUID converted from raw byte slice input.
// It will return an error if the slice isn't 16 bytes long.
func FromBytes(input []byte) (u UUID, err error) {
err = u.UnmarshalBinary(input)
return
}
// FromBytesOrNil returns UUID converted from raw byte slice input.
// Same behavior as FromBytes, but returns a Nil UUID on error.
func FromBytesOrNil(input []byte) UUID {
uuid, err := FromBytes(input)
if err != nil {
return Nil
}
return uuid
}
// FromString returns UUID parsed from string input.
// Input is expected in a form accepted by UnmarshalText.
func FromString(input string) (u UUID, err error) {
err = u.UnmarshalText([]byte(input))
return
}
// FromStringOrNil returns UUID parsed from string input.
// Same behavior as FromString, but returns a Nil UUID on error.
func FromStringOrNil(input string) UUID {
uuid, err := FromString(input)
if err != nil {
return Nil
}
return uuid
}
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by String.
func (u UUID) MarshalText() (text []byte, err error) {
text = []byte(u.String())
return
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// Following formats are supported:
// "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}",
// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8"
// "6ba7b8109dad11d180b400c04fd430c8"
// ABNF for supported UUID text representation follows:
// uuid := canonical | hashlike | braced | urn
// plain := canonical | hashlike
// canonical := 4hexoct '-' 2hexoct '-' 2hexoct '-' 6hexoct
// hashlike := 12hexoct
// braced := '{' plain '}'
// urn := URN ':' UUID-NID ':' plain
// URN := 'urn'
// UUID-NID := 'uuid'
// 12hexoct := 6hexoct 6hexoct
// 6hexoct := 4hexoct 2hexoct
// 4hexoct := 2hexoct 2hexoct
// 2hexoct := hexoct hexoct
// hexoct := hexdig hexdig
// hexdig := '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9' |
// 'a' | 'b' | 'c' | 'd' | 'e' | 'f' |
// 'A' | 'B' | 'C' | 'D' | 'E' | 'F'
func (u *UUID) UnmarshalText(text []byte) (err error) {
switch len(text) {
case 32:
return u.decodeHashLike(text)
case 36:
return u.decodeCanonical(text)
case 38:
return u.decodeBraced(text)
case 41:
fallthrough
case 45:
return u.decodeURN(text)
default:
return fmt.Errorf("uuid: incorrect UUID length: %s", text)
}
}
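// Illustrative parsing of the formats listed above via FromString, which
// delegates to UnmarshalText:
//
//     u, err := FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
//     u, err = FromString("{6ba7b810-9dad-11d1-80b4-00c04fd430c8}")
//     u, err = FromString("urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8")
//     u, err = FromString("6ba7b8109dad11d180b400c04fd430c8")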
// decodeCanonical decodes UUID string in format
// "6ba7b810-9dad-11d1-80b4-00c04fd430c8".
func (u *UUID) decodeCanonical(t []byte) (err error) {
if t[8] != '-' || t[13] != '-' || t[18] != '-' || t[23] != '-' {
return fmt.Errorf("uuid: incorrect UUID format %s", t)
}
src := t[:]
dst := u[:]
for i, byteGroup := range byteGroups {
if i > 0 {
src = src[1:] // skip dash
}
_, err = hex.Decode(dst[:byteGroup/2], src[:byteGroup])
if err != nil {
return
}
src = src[byteGroup:]
dst = dst[byteGroup/2:]
}
return
}
// decodeHashLike decodes UUID string in format
// "6ba7b8109dad11d180b400c04fd430c8".
func (u *UUID) decodeHashLike(t []byte) (err error) {
src := t[:]
dst := u[:]
if _, err = hex.Decode(dst, src); err != nil {
return err
}
return
}
// decodeBraced decodes UUID string in format
// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}" or in format
// "{6ba7b8109dad11d180b400c04fd430c8}".
func (u *UUID) decodeBraced(t []byte) (err error) {
l := len(t)
if t[0] != '{' || t[l-1] != '}' {
return fmt.Errorf("uuid: incorrect UUID format %s", t)
}
return u.decodePlain(t[1 : l-1])
}
// decodeURN decodes UUID string in format
// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" or in format
// "urn:uuid:6ba7b8109dad11d180b400c04fd430c8".
func (u *UUID) decodeURN(t []byte) (err error) {
total := len(t)
urn_uuid_prefix := t[:9]
if !bytes.Equal(urn_uuid_prefix, urnPrefix) {
return fmt.Errorf("uuid: incorrect UUID format: %s", t)
}
return u.decodePlain(t[9:total])
}
// decodePlain decodes UUID string in canonical format
// "6ba7b810-9dad-11d1-80b4-00c04fd430c8" or in hash-like format
// "6ba7b8109dad11d180b400c04fd430c8".
func (u *UUID) decodePlain(t []byte) (err error) {
switch len(t) {
case 32:
return u.decodeHashLike(t)
case 36:
return u.decodeCanonical(t)
default:
return fmt.Errorf("uuid: incorrrect UUID length: %s", t)
}
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (u UUID) MarshalBinary() (data []byte, err error) {
data = u.Bytes()
return
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It will return an error if the slice isn't 16 bytes long.
func (u *UUID) UnmarshalBinary(data []byte) (err error) {
if len(data) != Size {
err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data))
return
}
copy(u[:], data)
return
}

View File

@ -0,0 +1,239 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"hash"
"net"
"os"
"sync"
"time"
)
// Difference in 100-nanosecond intervals between
// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970).
const epochStart = 122192928000000000
var (
global = newDefaultGenerator()
epochFunc = unixTimeFunc
posixUID = uint32(os.Getuid())
posixGID = uint32(os.Getgid())
)
// NewV1 returns UUID based on current timestamp and MAC address.
func NewV1() UUID {
return global.NewV1()
}
// NewV2 returns DCE Security UUID based on POSIX UID/GID.
func NewV2(domain byte) UUID {
return global.NewV2(domain)
}
// NewV3 returns UUID based on MD5 hash of namespace UUID and name.
func NewV3(ns UUID, name string) UUID {
return global.NewV3(ns, name)
}
// NewV4 returns a randomly generated UUID.
func NewV4() UUID {
return global.NewV4()
}
// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name.
func NewV5(ns UUID, name string) UUID {
return global.NewV5(ns, name)
}
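// Illustrative usage of the package-level constructors above (NamespaceDNS is
// one of the predefined namespace UUIDs declared elsewhere in this package):
//
//     id := NewV4()                               // random UUID
//     named := NewV5(NamespaceDNS, "example.org") // name-based, SHA-1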
// Generator provides interface for generating UUIDs.
type Generator interface {
NewV1() UUID
NewV2(domain byte) UUID
NewV3(ns UUID, name string) UUID
NewV4() UUID
NewV5(ns UUID, name string) UUID
}
// Default generator implementation.
type generator struct {
storageOnce sync.Once
storageMutex sync.Mutex
lastTime uint64
clockSequence uint16
hardwareAddr [6]byte
}
func newDefaultGenerator() Generator {
return &generator{}
}
// NewV1 returns UUID based on current timestamp and MAC address.
func (g *generator) NewV1() UUID {
u := UUID{}
timeNow, clockSeq, hardwareAddr := g.getStorage()
binary.BigEndian.PutUint32(u[0:], uint32(timeNow))
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
copy(u[10:], hardwareAddr)
u.SetVersion(V1)
u.SetVariant(VariantRFC4122)
return u
}
// NewV2 returns DCE Security UUID based on POSIX UID/GID.
func (g *generator) NewV2(domain byte) UUID {
u := UUID{}
timeNow, clockSeq, hardwareAddr := g.getStorage()
switch domain {
case DomainPerson:
binary.BigEndian.PutUint32(u[0:], posixUID)
case DomainGroup:
binary.BigEndian.PutUint32(u[0:], posixGID)
}
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
u[9] = domain
copy(u[10:], hardwareAddr)
u.SetVersion(V2)
u.SetVariant(VariantRFC4122)
return u
}
// NewV3 returns UUID based on MD5 hash of namespace UUID and name.
func (g *generator) NewV3(ns UUID, name string) UUID {
u := newFromHash(md5.New(), ns, name)
u.SetVersion(V3)
u.SetVariant(VariantRFC4122)
return u
}
// NewV4 returns a randomly generated UUID.
func (g *generator) NewV4() UUID {
u := UUID{}
g.safeRandom(u[:])
u.SetVersion(V4)
u.SetVariant(VariantRFC4122)
return u
}
// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name.
func (g *generator) NewV5(ns UUID, name string) UUID {
u := newFromHash(sha1.New(), ns, name)
u.SetVersion(V5)
u.SetVariant(VariantRFC4122)
return u
}
func (g *generator) initStorage() {
g.initClockSequence()
g.initHardwareAddr()
}
func (g *generator) initClockSequence() {
buf := make([]byte, 2)
g.safeRandom(buf)
g.clockSequence = binary.BigEndian.Uint16(buf)
}
func (g *generator) initHardwareAddr() {
interfaces, err := net.Interfaces()
if err == nil {
for _, iface := range interfaces {
if len(iface.HardwareAddr) >= 6 {
copy(g.hardwareAddr[:], iface.HardwareAddr)
return
}
}
}
// Initialize hardwareAddr randomly in case
// no real network interface is present.
g.safeRandom(g.hardwareAddr[:])
// Set multicast bit as recommended in RFC 4122
g.hardwareAddr[0] |= 0x01
}
func (g *generator) safeRandom(dest []byte) {
if _, err := rand.Read(dest); err != nil {
panic(err)
}
}
// Returns UUID v1/v2 storage state.
// Returns epoch timestamp, clock sequence, and hardware address.
func (g *generator) getStorage() (uint64, uint16, []byte) {
g.storageOnce.Do(g.initStorage)
g.storageMutex.Lock()
defer g.storageMutex.Unlock()
timeNow := epochFunc()
// Clock changed backwards since last UUID generation.
// Should increase clock sequence.
if timeNow <= g.lastTime {
g.clockSequence++
}
g.lastTime = timeNow
return timeNow, g.clockSequence, g.hardwareAddr[:]
}
// unixTimeFunc returns the difference in 100-nanosecond intervals between the
// UUID epoch (October 15, 1582) and the current time.
// It is the default epoch calculation function.
func unixTimeFunc() uint64 {
return epochStart + uint64(time.Now().UnixNano()/100)
}
// Returns UUID based on hashing of namespace UUID and name.
func newFromHash(h hash.Hash, ns UUID, name string) UUID {
u := UUID{}
h.Write(ns[:])
h.Write([]byte(name))
copy(u[:], h.Sum(nil))
return u
}
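A brief usage sketch of the package-level constructors above (not part of the vendored file), again assuming the import path github.com/satori/go.uuid.

package main

import (
	"fmt"

	"github.com/satori/go.uuid"
)

func main() {
	// Time+MAC based (version 1) and random (version 4) UUIDs.
	u1 := uuid.NewV1()
	u4 := uuid.NewV4()

	// Name-based (version 5) UUID: deterministic for the same namespace and name.
	u5 := uuid.NewV5(uuid.NamespaceDNS, "example.com")

	fmt.Println(u1.Version(), u4.Version(), u5.Version()) // 1 4 5
	fmt.Println(u5) // always the same value for "example.com"
}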

View File

@ -0,0 +1,78 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"database/sql/driver"
"fmt"
)
// Value implements the driver.Valuer interface.
func (u UUID) Value() (driver.Value, error) {
return u.String(), nil
}
// Scan implements the sql.Scanner interface.
// A 16-byte slice is handled by UnmarshalBinary, while
// a longer byte slice or a string is handled by UnmarshalText.
func (u *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
if len(src) == Size {
return u.UnmarshalBinary(src)
}
return u.UnmarshalText(src)
case string:
return u.UnmarshalText([]byte(src))
}
return fmt.Errorf("uuid: cannot convert %T to UUID", src)
}
// NullUUID can be used with the standard sql package to represent a
// UUID value that can be NULL in the database
type NullUUID struct {
UUID UUID
Valid bool
}
// Value implements the driver.Valuer interface.
func (u NullUUID) Value() (driver.Value, error) {
if !u.Valid {
return nil, nil
}
// Delegate to UUID Value function
return u.UUID.Value()
}
// Scan implements the sql.Scanner interface.
func (u *NullUUID) Scan(src interface{}) error {
if src == nil {
u.UUID, u.Valid = Nil, false
return nil
}
// Delegate to UUID Scan function
u.Valid = true
return u.UUID.Scan(src)
}
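A hedged sketch of NullUUID with database/sql; the package name, table, and column below are hypothetical and only illustrate how Scan and Valid are meant to be used.

package store

import (
	"database/sql"
	"fmt"

	"github.com/satori/go.uuid"
)

// readOwner shows a nullable UUID column round-tripping through database/sql.
// "users" and "owner_uuid" are invented names for this sketch.
func readOwner(db *sql.DB, name string) error {
	var owner uuid.NullUUID
	err := db.QueryRow("SELECT owner_uuid FROM users WHERE name = ?", name).Scan(&owner)
	if err != nil {
		return err
	}
	if owner.Valid {
		fmt.Println("owner:", owner.UUID)
	} else {
		fmt.Println("owner is NULL")
	}
	return nil
}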

View File

@ -0,0 +1,161 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// Package uuid provides implementation of Universally Unique Identifier (UUID).
// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and
// version 2 (as specified in DCE 1.1).
package uuid
import (
"bytes"
"encoding/hex"
)
// Size of a UUID in bytes.
const Size = 16
// UUID representation compliant with specification
// described in RFC 4122.
type UUID [Size]byte
// UUID versions
const (
_ byte = iota
V1
V2
V3
V4
V5
)
// UUID layout variants.
const (
VariantNCS byte = iota
VariantRFC4122
VariantMicrosoft
VariantFuture
)
// UUID DCE domains.
const (
DomainPerson = iota
DomainGroup
DomainOrg
)
// String parse helpers.
var (
urnPrefix = []byte("urn:uuid:")
byteGroups = []int{8, 4, 4, 4, 12}
)
// Nil is the special form of UUID that is specified to have all
// 128 bits set to zero.
var Nil = UUID{}
// Predefined namespace UUIDs.
var (
NamespaceDNS = Must(FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))
NamespaceURL = Must(FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8"))
NamespaceOID = Must(FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8"))
NamespaceX500 = Must(FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8"))
)
// Equal returns true if u1 and u2 are equal, otherwise returns false.
func Equal(u1 UUID, u2 UUID) bool {
return bytes.Equal(u1[:], u2[:])
}
// Version returns algorithm version used to generate UUID.
func (u UUID) Version() byte {
return u[6] >> 4
}
// Variant returns UUID layout variant.
func (u UUID) Variant() byte {
switch {
case (u[8] >> 7) == 0x00:
return VariantNCS
case (u[8] >> 6) == 0x02:
return VariantRFC4122
case (u[8] >> 5) == 0x06:
return VariantMicrosoft
case (u[8] >> 5) == 0x07:
fallthrough
default:
return VariantFuture
}
}
// Bytes returns bytes slice representation of UUID.
func (u UUID) Bytes() []byte {
return u[:]
}
// String returns the canonical string representation of UUID:
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.
func (u UUID) String() string {
buf := make([]byte, 36)
hex.Encode(buf[0:8], u[0:4])
buf[8] = '-'
hex.Encode(buf[9:13], u[4:6])
buf[13] = '-'
hex.Encode(buf[14:18], u[6:8])
buf[18] = '-'
hex.Encode(buf[19:23], u[8:10])
buf[23] = '-'
hex.Encode(buf[24:], u[10:])
return string(buf)
}
// SetVersion sets version bits.
func (u *UUID) SetVersion(v byte) {
u[6] = (u[6] & 0x0f) | (v << 4)
}
// SetVariant sets variant bits.
func (u *UUID) SetVariant(v byte) {
switch v {
case VariantNCS:
u[8] = (u[8]&(0xff>>1) | (0x00 << 7))
case VariantRFC4122:
u[8] = (u[8]&(0xff>>2) | (0x02 << 6))
case VariantMicrosoft:
u[8] = (u[8]&(0xff>>3) | (0x06 << 5))
case VariantFuture:
fallthrough
default:
u[8] = (u[8]&(0xff>>3) | (0x07 << 5))
}
}
// Must is a helper that wraps a call to a function returning (UUID, error)
// and panics if the error is non-nil. It is intended for use in variable
// initializations such as
// var packageUUID = uuid.Must(uuid.FromString("123e4567-e89b-12d3-a456-426655440000"))
func Must(u UUID, err error) UUID {
if err != nil {
panic(err)
}
return u
}
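A short sketch exercising Must, FromString, and the accessors defined above; the literal UUID string is arbitrary.

package main

import (
	"fmt"

	"github.com/satori/go.uuid"
)

var packageUUID = uuid.Must(uuid.FromString("123e4567-e89b-12d3-a456-426655440000"))

func main() {
	fmt.Println(packageUUID.Version())                        // 1
	fmt.Println(packageUUID.Variant() == uuid.VariantRFC4122) // true
	fmt.Println(uuid.Equal(packageUUID, uuid.Nil))            // false
}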

View File

@ -0,0 +1,45 @@
The MIT License (MIT)
Copyright (c) 2015 Spring, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
- Based on https://github.com/oguzbilgic/fpd, which has the following license:
"""
The MIT License (MIT)
Copyright (c) 2013 Oguz Bilgic
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

View File

@ -0,0 +1,414 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Multiprecision decimal numbers.
// For floating-point formatting only; not general purpose.
// Only operations are assign and (binary) left/right shift.
// Can do binary floating point in multiprecision decimal precisely
// because 2 divides 10; cannot do decimal floating point
// in multiprecision binary precisely.
package decimal
type decimal struct {
d [800]byte // digits, big-endian representation
nd int // number of digits used
dp int // decimal point
neg bool // negative flag
trunc bool // discarded nonzero digits beyond d[:nd]
}
func (a *decimal) String() string {
n := 10 + a.nd
if a.dp > 0 {
n += a.dp
}
if a.dp < 0 {
n += -a.dp
}
buf := make([]byte, n)
w := 0
switch {
case a.nd == 0:
return "0"
case a.dp <= 0:
// zeros fill space between decimal point and digits
buf[w] = '0'
w++
buf[w] = '.'
w++
w += digitZero(buf[w : w+-a.dp])
w += copy(buf[w:], a.d[0:a.nd])
case a.dp < a.nd:
// decimal point in middle of digits
w += copy(buf[w:], a.d[0:a.dp])
buf[w] = '.'
w++
w += copy(buf[w:], a.d[a.dp:a.nd])
default:
// zeros fill space between digits and decimal point
w += copy(buf[w:], a.d[0:a.nd])
w += digitZero(buf[w : w+a.dp-a.nd])
}
return string(buf[0:w])
}
func digitZero(dst []byte) int {
for i := range dst {
dst[i] = '0'
}
return len(dst)
}
// trim trailing zeros from number.
// (They are meaningless; the decimal point is tracked
// independent of the number of digits.)
func trim(a *decimal) {
for a.nd > 0 && a.d[a.nd-1] == '0' {
a.nd--
}
if a.nd == 0 {
a.dp = 0
}
}
// Assign v to a.
func (a *decimal) Assign(v uint64) {
var buf [24]byte
// Write reversed decimal in buf.
n := 0
for v > 0 {
v1 := v / 10
v -= 10 * v1
buf[n] = byte(v + '0')
n++
v = v1
}
// Reverse again to produce forward decimal in a.d.
a.nd = 0
for n--; n >= 0; n-- {
a.d[a.nd] = buf[n]
a.nd++
}
a.dp = a.nd
trim(a)
}
// Maximum shift that we can do in one pass without overflow.
// A uint has 32 or 64 bits, and we have to be able to accommodate 9<<k.
const uintSize = 32 << (^uint(0) >> 63)
const maxShift = uintSize - 4
// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow.
func rightShift(a *decimal, k uint) {
r := 0 // read pointer
w := 0 // write pointer
// Pick up enough leading digits to cover first shift.
var n uint
for ; n>>k == 0; r++ {
if r >= a.nd {
if n == 0 {
// a == 0; shouldn't get here, but handle anyway.
a.nd = 0
return
}
for n>>k == 0 {
n = n * 10
r++
}
break
}
c := uint(a.d[r])
n = n*10 + c - '0'
}
a.dp -= r - 1
var mask uint = (1 << k) - 1
// Pick up a digit, put down a digit.
for ; r < a.nd; r++ {
c := uint(a.d[r])
dig := n >> k
n &= mask
a.d[w] = byte(dig + '0')
w++
n = n*10 + c - '0'
}
// Put down extra digits.
for n > 0 {
dig := n >> k
n &= mask
if w < len(a.d) {
a.d[w] = byte(dig + '0')
w++
} else if dig > 0 {
a.trunc = true
}
n = n * 10
}
a.nd = w
trim(a)
}
// Cheat sheet for left shift: table indexed by shift count giving
// number of new digits that will be introduced by that shift.
//
// For example, leftcheats[4] = {2, "625"}. That means that
// if we are shifting by 4 (multiplying by 16), it will add 2 digits
// when the string prefix is "625" through "999", and one fewer digit
// if the string prefix is "000" through "624".
//
// Credit for this trick goes to Ken.
type leftCheat struct {
delta int // number of new digits
cutoff string // minus one digit if original < a.
}
var leftcheats = []leftCheat{
// Leading digits of 1/2^i = 5^i.
// 5^23 is not an exact 64-bit floating point number,
// so have to use bc for the math.
// Go up to 60 to be large enough for 32bit and 64bit platforms.
/*
seq 60 | sed 's/^/5^/' | bc |
awk 'BEGIN{ print "\t{ 0, \"\" }," }
{
log2 = log(2)/log(10)
printf("\t{ %d, \"%s\" },\t// * %d\n",
int(log2*NR+1), $0, 2**NR)
}'
*/
{0, ""},
{1, "5"}, // * 2
{1, "25"}, // * 4
{1, "125"}, // * 8
{2, "625"}, // * 16
{2, "3125"}, // * 32
{2, "15625"}, // * 64
{3, "78125"}, // * 128
{3, "390625"}, // * 256
{3, "1953125"}, // * 512
{4, "9765625"}, // * 1024
{4, "48828125"}, // * 2048
{4, "244140625"}, // * 4096
{4, "1220703125"}, // * 8192
{5, "6103515625"}, // * 16384
{5, "30517578125"}, // * 32768
{5, "152587890625"}, // * 65536
{6, "762939453125"}, // * 131072
{6, "3814697265625"}, // * 262144
{6, "19073486328125"}, // * 524288
{7, "95367431640625"}, // * 1048576
{7, "476837158203125"}, // * 2097152
{7, "2384185791015625"}, // * 4194304
{7, "11920928955078125"}, // * 8388608
{8, "59604644775390625"}, // * 16777216
{8, "298023223876953125"}, // * 33554432
{8, "1490116119384765625"}, // * 67108864
{9, "7450580596923828125"}, // * 134217728
{9, "37252902984619140625"}, // * 268435456
{9, "186264514923095703125"}, // * 536870912
{10, "931322574615478515625"}, // * 1073741824
{10, "4656612873077392578125"}, // * 2147483648
{10, "23283064365386962890625"}, // * 4294967296
{10, "116415321826934814453125"}, // * 8589934592
{11, "582076609134674072265625"}, // * 17179869184
{11, "2910383045673370361328125"}, // * 34359738368
{11, "14551915228366851806640625"}, // * 68719476736
{12, "72759576141834259033203125"}, // * 137438953472
{12, "363797880709171295166015625"}, // * 274877906944
{12, "1818989403545856475830078125"}, // * 549755813888
{13, "9094947017729282379150390625"}, // * 1099511627776
{13, "45474735088646411895751953125"}, // * 2199023255552
{13, "227373675443232059478759765625"}, // * 4398046511104
{13, "1136868377216160297393798828125"}, // * 8796093022208
{14, "5684341886080801486968994140625"}, // * 17592186044416
{14, "28421709430404007434844970703125"}, // * 35184372088832
{14, "142108547152020037174224853515625"}, // * 70368744177664
{15, "710542735760100185871124267578125"}, // * 140737488355328
{15, "3552713678800500929355621337890625"}, // * 281474976710656
{15, "17763568394002504646778106689453125"}, // * 562949953421312
{16, "88817841970012523233890533447265625"}, // * 1125899906842624
{16, "444089209850062616169452667236328125"}, // * 2251799813685248
{16, "2220446049250313080847263336181640625"}, // * 4503599627370496
{16, "11102230246251565404236316680908203125"}, // * 9007199254740992
{17, "55511151231257827021181583404541015625"}, // * 18014398509481984
{17, "277555756156289135105907917022705078125"}, // * 36028797018963968
{17, "1387778780781445675529539585113525390625"}, // * 72057594037927936
{18, "6938893903907228377647697925567626953125"}, // * 144115188075855872
{18, "34694469519536141888238489627838134765625"}, // * 288230376151711744
{18, "173472347597680709441192448139190673828125"}, // * 576460752303423488
{19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976
}
// Is the leading prefix of b lexicographically less than s?
func prefixIsLessThan(b []byte, s string) bool {
for i := 0; i < len(s); i++ {
if i >= len(b) {
return true
}
if b[i] != s[i] {
return b[i] < s[i]
}
}
return false
}
// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow.
func leftShift(a *decimal, k uint) {
delta := leftcheats[k].delta
if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) {
delta--
}
r := a.nd // read index
w := a.nd + delta // write index
// Pick up a digit, put down a digit.
var n uint
for r--; r >= 0; r-- {
n += (uint(a.d[r]) - '0') << k
quo := n / 10
rem := n - 10*quo
w--
if w < len(a.d) {
a.d[w] = byte(rem + '0')
} else if rem != 0 {
a.trunc = true
}
n = quo
}
// Put down extra digits.
for n > 0 {
quo := n / 10
rem := n - 10*quo
w--
if w < len(a.d) {
a.d[w] = byte(rem + '0')
} else if rem != 0 {
a.trunc = true
}
n = quo
}
a.nd += delta
if a.nd >= len(a.d) {
a.nd = len(a.d)
}
a.dp += delta
trim(a)
}
// Binary shift left (k > 0) or right (k < 0).
func (a *decimal) Shift(k int) {
switch {
case a.nd == 0:
// nothing to do: a == 0
case k > 0:
for k > maxShift {
leftShift(a, maxShift)
k -= maxShift
}
leftShift(a, uint(k))
case k < 0:
for k < -maxShift {
rightShift(a, maxShift)
k += maxShift
}
rightShift(a, uint(-k))
}
}
// If we chop a at nd digits, should we round up?
func shouldRoundUp(a *decimal, nd int) bool {
if nd < 0 || nd >= a.nd {
return false
}
if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even
// if we truncated, a little higher than what's recorded - always round up
if a.trunc {
return true
}
return nd > 0 && (a.d[nd-1]-'0')%2 != 0
}
// not halfway - digit tells all
return a.d[nd] >= '5'
}
// Round a to nd digits (or fewer).
// If nd is zero, it means we're rounding
// just to the left of the digits, as in
// 0.09 -> 0.1.
func (a *decimal) Round(nd int) {
if nd < 0 || nd >= a.nd {
return
}
if shouldRoundUp(a, nd) {
a.RoundUp(nd)
} else {
a.RoundDown(nd)
}
}
// Round a down to nd digits (or fewer).
func (a *decimal) RoundDown(nd int) {
if nd < 0 || nd >= a.nd {
return
}
a.nd = nd
trim(a)
}
// Round a up to nd digits (or fewer).
func (a *decimal) RoundUp(nd int) {
if nd < 0 || nd >= a.nd {
return
}
// round up
for i := nd - 1; i >= 0; i-- {
c := a.d[i]
if c < '9' { // can stop after this digit
a.d[i]++
a.nd = i + 1
return
}
}
// Number is all 9s.
// Change to single 1 with adjusted decimal point.
a.d[0] = '1'
a.nd = 1
a.dp++
}
// Extract integer part, rounded appropriately.
// No guarantees about overflow.
func (a *decimal) RoundedInteger() uint64 {
if a.dp > 20 {
return 0xFFFFFFFFFFFFFFFF
}
var i int
n := uint64(0)
for i = 0; i < a.dp && i < a.nd; i++ {
n = n*10 + uint64(a.d[i]-'0')
}
for ; i < a.dp; i++ {
n *= 10
}
if shouldRoundUp(a, a.dp) {
n++
}
return n
}
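Because the decimal type above is unexported, any example has to live inside the package (for instance in a _test.go file); the following is a package-internal sketch of the assign/shift cycle and is not part of the vendored file.

package decimal

import "fmt"

// exampleShift is illustrative only: Shift moves the binary point, while the
// decimal digits and decimal point are tracked in d, nd and dp.
func exampleShift() {
	var d decimal
	d.Assign(15)            // digits "15", dp = 2
	d.Shift(3)              // multiply by 2^3
	fmt.Println(d.String()) // "120"
	d.Shift(-4)             // divide by 2^4
	fmt.Println(d.String()) // "7.5"
}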

File diff suppressed because it is too large

View File

@ -0,0 +1,118 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Multiprecision decimal numbers.
// For floating-point formatting only; not general purpose.
// Only operations are assign and (binary) left/right shift.
// Can do binary floating point in multiprecision decimal precisely
// because 2 divides 10; cannot do decimal floating point
// in multiprecision binary precisely.
package decimal
type floatInfo struct {
mantbits uint
expbits uint
bias int
}
var float32info = floatInfo{23, 8, -127}
var float64info = floatInfo{52, 11, -1023}
// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits
// that will let the original floating point value be precisely reconstructed.
func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
// If mantissa is zero, the number is zero; stop now.
if mant == 0 {
d.nd = 0
return
}
// Compute upper and lower such that any decimal number
// between upper and lower (possibly inclusive)
// will round to the original floating point number.
// We may see at once that the number is already shortest.
//
// Suppose d is not denormal, so that 2^exp <= d < 10^dp.
// The closest shorter number is at least 10^(dp-nd) away.
// The lower/upper bounds computed below are at distance
// at most 2^(exp-mantbits).
//
// So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits),
// or equivalently log2(10)*(dp-nd) > exp-mantbits.
// It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32).
minexp := flt.bias + 1 // minimum possible exponent
if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) {
// The number is already shortest.
return
}
// d = mant << (exp - mantbits)
// Next highest floating point number is mant+1 << exp-mantbits.
// Our upper bound is halfway between, mant*2+1 << exp-mantbits-1.
upper := new(decimal)
upper.Assign(mant*2 + 1)
upper.Shift(exp - int(flt.mantbits) - 1)
// d = mant << (exp - mantbits)
// Next lowest floating point number is mant-1 << exp-mantbits,
// unless mant-1 drops the significant bit and exp is not the minimum exp,
// in which case the next lowest is mant*2-1 << exp-mantbits-1.
// Either way, call it mantlo << explo-mantbits.
// Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1.
var mantlo uint64
var explo int
if mant > 1<<flt.mantbits || exp == minexp {
mantlo = mant - 1
explo = exp
} else {
mantlo = mant*2 - 1
explo = exp - 1
}
lower := new(decimal)
lower.Assign(mantlo*2 + 1)
lower.Shift(explo - int(flt.mantbits) - 1)
// The upper and lower bounds are possible outputs only if
// the original mantissa is even, so that IEEE round-to-even
// would round to the original mantissa and not the neighbors.
inclusive := mant%2 == 0
// Now we can figure out the minimum number of digits required.
// Walk along until d has distinguished itself from upper and lower.
for i := 0; i < d.nd; i++ {
l := byte('0') // lower digit
if i < lower.nd {
l = lower.d[i]
}
m := d.d[i] // middle digit
u := byte('0') // upper digit
if i < upper.nd {
u = upper.d[i]
}
// Okay to round down (truncate) if lower has a different digit
// or if lower is inclusive and is exactly the result of rounding
// down (i.e., and we have reached the final digit of lower).
okdown := l != m || inclusive && i+1 == lower.nd
// Okay to round up if upper has a different digit and either upper
// is inclusive or upper is bigger than the result of rounding up.
okup := m != u && (inclusive || m+1 < u || i+1 < upper.nd)
// If it's okay to do either, then round to the nearest one.
// If it's okay to do only one, do it.
switch {
case okdown && okup:
d.Round(i + 1)
return
case okdown:
d.RoundDown(i + 1)
return
case okup:
d.RoundUp(i + 1)
return
}
}
}
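roundShortest above mirrors the shortest-representation logic in Go's strconv; as an external illustration of the same idea (not this package's API), strconv.FormatFloat with prec = -1 asks for the fewest digits that still parse back to the identical float64.

package main

import (
	"fmt"
	"strconv"
)

func main() {
	// prec = -1 requests the shortest decimal that round-trips, which is what
	// roundShortest computes internally for this package's formatting.
	fmt.Println(strconv.FormatFloat(0.1, 'g', -1, 64)) // 0.1
	fmt.Println(strconv.FormatFloat(2.5, 'g', -1, 64)) // 2.5
}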

View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 siddontang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,20 @@
// Package log supplies more advanced features than Go's original log package.
//
// It supports different log levels: trace, debug, info, warn, error, fatal.
//
// It also supports different log handlers, so you can log to stdout, a file, a socket, etc.
//
// Use
//
// import "github.com/siddontang/go-log/log"
//
// //log with different level
// log.Info("hello world")
// log.Error("hello world")
//
// //create a logger with specified handler
// h := NewStreamHandler(os.Stdout)
// l := log.NewDefault(h)
// l.Info("hello world")
//
package log

View File

@ -0,0 +1,230 @@
package log
import (
"fmt"
"os"
"path"
"time"
)
// FileHandler writes log to a file.
type FileHandler struct {
fd *os.File
}
// NewFileHandler creates a FileHandler
func NewFileHandler(fileName string, flag int) (*FileHandler, error) {
dir := path.Dir(fileName)
os.Mkdir(dir, 0777)
f, err := os.OpenFile(fileName, flag, 0755)
if err != nil {
return nil, err
}
h := new(FileHandler)
h.fd = f
return h, nil
}
// Write implements Handler interface
func (h *FileHandler) Write(b []byte) (n int, err error) {
return h.fd.Write(b)
}
// Close implements Handler interface
func (h *FileHandler) Close() error {
return h.fd.Close()
}
// RotatingFileHandler writes logs to a file; if the file size exceeds maxBytes,
// it backs up the current file and opens a new one.
//
// The maximum number of backup files is set by backupCount; the oldest backup is deleted when there are too many.
type RotatingFileHandler struct {
fd *os.File
fileName string
maxBytes int
curBytes int
backupCount int
}
// NewRotatingFileHandler creates a RotatingFileHandler
func NewRotatingFileHandler(fileName string, maxBytes int, backupCount int) (*RotatingFileHandler, error) {
dir := path.Dir(fileName)
os.MkdirAll(dir, 0777)
h := new(RotatingFileHandler)
if maxBytes <= 0 {
return nil, fmt.Errorf("invalid max bytes")
}
h.fileName = fileName
h.maxBytes = maxBytes
h.backupCount = backupCount
var err error
h.fd, err = os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, err
}
f, err := h.fd.Stat()
if err != nil {
return nil, err
}
h.curBytes = int(f.Size())
return h, nil
}
// Write implements Handler interface
func (h *RotatingFileHandler) Write(p []byte) (n int, err error) {
h.doRollover()
n, err = h.fd.Write(p)
h.curBytes += n
return
}
// Close implements Handler interface
func (h *RotatingFileHandler) Close() error {
if h.fd != nil {
return h.fd.Close()
}
return nil
}
func (h *RotatingFileHandler) doRollover() {
if h.curBytes < h.maxBytes {
return
}
f, err := h.fd.Stat()
if err != nil {
return
}
if h.maxBytes <= 0 {
return
} else if f.Size() < int64(h.maxBytes) {
h.curBytes = int(f.Size())
return
}
if h.backupCount > 0 {
h.fd.Close()
for i := h.backupCount - 1; i > 0; i-- {
sfn := fmt.Sprintf("%s.%d", h.fileName, i)
dfn := fmt.Sprintf("%s.%d", h.fileName, i+1)
os.Rename(sfn, dfn)
}
dfn := fmt.Sprintf("%s.1", h.fileName)
os.Rename(h.fileName, dfn)
h.fd, _ = os.OpenFile(h.fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
h.curBytes = 0
f, err := h.fd.Stat()
if err != nil {
return
}
h.curBytes = int(f.Size())
}
}
// TimeRotatingFileHandler writes logs to a file; it backs up the current file
// and opens a new one at the period you specify.
//
// Refer to http://docs.python.org/2/library/logging.handlers.html;
// it behaves like Python's TimedRotatingFileHandler.
type TimeRotatingFileHandler struct {
fd *os.File
baseName string
interval int64
suffix string
rolloverAt int64
}
// TimeRotating way
const (
WhenSecond = iota
WhenMinute
WhenHour
WhenDay
)
// NewTimeRotatingFileHandler creates a TimeRotatingFileHandler
func NewTimeRotatingFileHandler(baseName string, when int8, interval int) (*TimeRotatingFileHandler, error) {
dir := path.Dir(baseName)
os.MkdirAll(dir, 0777)
h := new(TimeRotatingFileHandler)
h.baseName = baseName
switch when {
case WhenSecond:
h.interval = 1
h.suffix = "2006-01-02_15-04-05"
case WhenMinute:
h.interval = 60
h.suffix = "2006-01-02_15-04"
case WhenHour:
h.interval = 3600
h.suffix = "2006-01-02_15"
case WhenDay:
h.interval = 3600 * 24
h.suffix = "2006-01-02"
default:
return nil, fmt.Errorf("invalid when_rotate: %d", when)
}
h.interval = h.interval * int64(interval)
var err error
h.fd, err = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, err
}
fInfo, _ := h.fd.Stat()
h.rolloverAt = fInfo.ModTime().Unix() + h.interval
return h, nil
}
func (h *TimeRotatingFileHandler) doRollover() {
//refer http://hg.python.org/cpython/file/2.7/Lib/logging/handlers.py
now := time.Now()
if h.rolloverAt <= now.Unix() {
fName := h.baseName + now.Format(h.suffix)
h.fd.Close()
e := os.Rename(h.baseName, fName)
if e != nil {
panic(e)
}
h.fd, _ = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
h.rolloverAt = time.Now().Unix() + h.interval
}
}
// Write implements Handler interface
func (h *TimeRotatingFileHandler) Write(b []byte) (n int, err error) {
h.doRollover()
return h.fd.Write(b)
}
// Close implements Handler interface
func (h *TimeRotatingFileHandler) Close() error {
return h.fd.Close()
}
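A small sketch wiring the rotating handler above into a logger, assuming the vendored import path github.com/siddontang/go-log/log; the file name and size limit are illustrative.

package main

import (
	"github.com/siddontang/go-log/log"
)

func main() {
	// Rotate when the file reaches ~1 MB, keeping at most 3 backups.
	h, err := log.NewRotatingFileHandler("./app.log", 1024*1024, 3)
	if err != nil {
		panic(err)
	}
	l := log.NewDefault(h) // Ltime|Lfile|Llevel
	defer l.Close()

	l.Infof("service started, pid=%d", 1234)
}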

View File

@ -0,0 +1,54 @@
package log
import (
"io"
)
// Handler writes logs to somewhere
type Handler interface {
Write(p []byte) (n int, err error)
Close() error
}
// StreamHandler writes logs to a specified io.Writer, such as stdout, stderr, etc.
type StreamHandler struct {
w io.Writer
}
// NewStreamHandler creates a StreamHandler
func NewStreamHandler(w io.Writer) (*StreamHandler, error) {
h := new(StreamHandler)
h.w = w
return h, nil
}
// Write implements Handler interface
func (h *StreamHandler) Write(b []byte) (n int, err error) {
return h.w.Write(b)
}
// Close implements Handler interface
func (h *StreamHandler) Close() error {
return nil
}
// NullHandler does nothing; it discards everything written to it.
type NullHandler struct {
}
// NewNullHandler creates a NullHandler
func NewNullHandler() (*NullHandler, error) {
return new(NullHandler), nil
}
// Write implements Handler interface
func (h *NullHandler) Write(b []byte) (n int, err error) {
return len(b), nil
}
// Close implements Handler interface
func (h *NullHandler) Close() error {
return nil
}
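A minimal sketch of a StreamHandler writing to stderr with a reduced flag set; the import path is the one used by this vendor tree.

package main

import (
	"os"

	"github.com/siddontang/go-log/log"
)

func main() {
	h, err := log.NewStreamHandler(os.Stderr)
	if err != nil {
		panic(err)
	}
	l := log.New(h, log.Ltime|log.Llevel) // no file:line prefix
	defer l.Close()

	l.Warn("disk space running low")
}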

View File

@ -0,0 +1,137 @@
package log
import (
"fmt"
"os"
)
var logger = NewDefault(newStdHandler())
// SetDefaultLogger changes the global logger
func SetDefaultLogger(l *Logger) {
logger = l
}
// SetLevel changes the logger level
func SetLevel(level Level) {
logger.SetLevel(level)
}
// SetLevelByName changes the logger level by name
func SetLevelByName(name string) {
logger.SetLevelByName(name)
}
// Fatal records the log with fatal level and exits
func Fatal(args ...interface{}) {
logger.Output(2, LevelFatal, fmt.Sprint(args...))
os.Exit(1)
}
// Fatalf records the log with fatal level and exits
func Fatalf(format string, args ...interface{}) {
logger.Output(2, LevelFatal, fmt.Sprintf(format, args...))
os.Exit(1)
}
// Fatalln records the log with fatal level and exits
func Fatalln(args ...interface{}) {
logger.Output(2, LevelFatal, fmt.Sprintln(args...))
os.Exit(1)
}
// Panic records the log with fatal level and panics
func Panic(args ...interface{}) {
msg := fmt.Sprint(args...)
logger.Output(2, LevelError, msg)
panic(msg)
}
// Panicf records the log with fatal level and panics
func Panicf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
logger.Output(2, LevelError, msg)
panic(msg)
}
// Panicln records the log with fatal level and panics
func Panicln(args ...interface{}) {
msg := fmt.Sprintln(args...)
logger.Output(2, LevelError, msg)
panic(msg)
}
// Print records the log with trace level
func Print(args ...interface{}) {
logger.Output(2, LevelTrace, fmt.Sprint(args...))
}
// Printf records the log with trace level
func Printf(format string, args ...interface{}) {
logger.Output(2, LevelTrace, fmt.Sprintf(format, args...))
}
// Println records the log with trace level
func Println(args ...interface{}) {
logger.Output(2, LevelTrace, fmt.Sprintln(args...))
}
// Debug records the log with debug level
func Debug(args ...interface{}) {
logger.Output(2, LevelDebug, fmt.Sprint(args...))
}
// Debugf records the log with debug level
func Debugf(format string, args ...interface{}) {
logger.Output(2, LevelDebug, fmt.Sprintf(format, args...))
}
// Debugln records the log with debug level
func Debugln(args ...interface{}) {
logger.Output(2, LevelDebug, fmt.Sprintln(args...))
}
// Error records the log with error level
func Error(args ...interface{}) {
logger.Output(2, LevelError, fmt.Sprint(args...))
}
// Errorf records the log with error level
func Errorf(format string, args ...interface{}) {
logger.Output(2, LevelError, fmt.Sprintf(format, args...))
}
// Errorln records the log with error level
func Errorln(args ...interface{}) {
logger.Output(2, LevelError, fmt.Sprintln(args...))
}
// Info records the log with info level
func Info(args ...interface{}) {
logger.Output(2, LevelInfo, fmt.Sprint(args...))
}
// Infof records the log with info level
func Infof(format string, args ...interface{}) {
logger.Output(2, LevelInfo, fmt.Sprintf(format, args...))
}
// Infoln records the log with info level
func Infoln(args ...interface{}) {
logger.Output(2, LevelInfo, fmt.Sprintln(args...))
}
// Warn records the log with warn level
func Warn(args ...interface{}) {
logger.Output(2, LevelWarn, fmt.Sprint(args...))
}
// Warnf records the log with warn level
func Warnf(format string, args ...interface{}) {
logger.Output(2, LevelWarn, fmt.Sprintf(format, args...))
}
// Warnln records the log with warn level
func Warnln(args ...interface{}) {
logger.Output(2, LevelWarn, fmt.Sprintln(args...))
}
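The package-level helpers above all delegate to a global logger that writes to stdout with Ltime|Lfile|Llevel, so callers only need the log package itself; a short usage sketch:

package main

import (
	"github.com/siddontang/go-log/log"
)

func main() {
	log.SetLevelByName("debug")

	log.Debugf("connecting to %s", "127.0.0.1:3306")
	log.Info("connected")
	log.Warn("replication lag detected")
}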

View File

@ -0,0 +1,340 @@
package log
import (
"fmt"
"os"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/siddontang/go-log/loggers"
)
const (
timeFormat = "2006/01/02 15:04:05"
maxBufPoolSize = 16
)
// Logger flag
const (
Ltime = 1 << iota // time format "2006/01/02 15:04:05"
Lfile // file.go:123
Llevel // [Trace|Debug|Info...]
)
// Level type
type Level int
// Log levels, from low to high; a higher level means more serious
const (
LevelTrace Level = iota
LevelDebug
LevelInfo
LevelWarn
LevelError
LevelFatal
)
// String returns the name of the level
func (l Level) String() string {
switch l {
case LevelTrace:
return "trace"
case LevelDebug:
return "debug"
case LevelInfo:
return "info"
case LevelWarn:
return "warn"
case LevelError:
return "error"
case LevelFatal:
return "fatal"
}
// return default info
return "info"
}
// Logger is the logger to record log
type Logger struct {
// TODO: support logger.Contextual
loggers.Advanced
sync.Mutex
level Level
flag int
handler Handler
quit chan struct{}
msg chan []byte
bufs [][]byte
}
// New creates a logger with specified handler and flag
func New(handler Handler, flag int) *Logger {
var l = new(Logger)
l.level = LevelInfo
l.handler = handler
l.flag = flag
l.quit = make(chan struct{})
l.msg = make(chan []byte, 1024)
l.bufs = make([][]byte, 0, 16)
go l.run()
return l
}
// NewDefault creates a default logger with the specified handler and the flags Ltime|Lfile|Llevel
func NewDefault(handler Handler) *Logger {
return New(handler, Ltime|Lfile|Llevel)
}
func newStdHandler() *StreamHandler {
h, _ := NewStreamHandler(os.Stdout)
return h
}
func (l *Logger) run() {
for {
select {
case msg := <-l.msg:
l.handler.Write(msg)
l.putBuf(msg)
case <-l.quit:
l.handler.Close()
}
}
}
func (l *Logger) popBuf() []byte {
l.Lock()
var buf []byte
if len(l.bufs) == 0 {
buf = make([]byte, 0, 1024)
} else {
buf = l.bufs[len(l.bufs)-1]
l.bufs = l.bufs[0 : len(l.bufs)-1]
}
l.Unlock()
return buf
}
func (l *Logger) putBuf(buf []byte) {
l.Lock()
if len(l.bufs) < maxBufPoolSize {
buf = buf[0:0]
l.bufs = append(l.bufs, buf)
}
l.Unlock()
}
// Close closes the logger
func (l *Logger) Close() {
if l.quit == nil {
return
}
close(l.quit)
l.quit = nil
}
// SetLevel sets the log level; messages below this level will not be logged
func (l *Logger) SetLevel(level Level) {
l.level = level
}
// SetLevelByName sets log level by name
func (l *Logger) SetLevelByName(name string) {
level := LevelInfo
switch strings.ToLower(name) {
case "trace":
level = LevelTrace
case "debug":
level = LevelDebug
case "warn", "warning":
level = LevelWarn
case "error":
level = LevelError
case "fatal":
level = LevelFatal
default:
level = LevelInfo
}
l.SetLevel(level)
}
// Output records the log with the specified callstack depth and log level.
func (l *Logger) Output(callDepth int, level Level, msg string) {
if l.level > level {
return
}
buf := l.popBuf()
if l.flag&Ltime > 0 {
now := time.Now().Format(timeFormat)
buf = append(buf, '[')
buf = append(buf, now...)
buf = append(buf, "] "...)
}
if l.flag&Llevel > 0 {
buf = append(buf, '[')
buf = append(buf, level.String()...)
buf = append(buf, "] "...)
}
if l.flag&Lfile > 0 {
_, file, line, ok := runtime.Caller(callDepth)
if !ok {
file = "???"
line = 0
} else {
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
file = file[i+1:]
break
}
}
}
buf = append(buf, file...)
buf = append(buf, ':')
buf = strconv.AppendInt(buf, int64(line), 10)
buf = append(buf, ' ')
}
buf = append(buf, msg...)
if len(msg) == 0 || msg[len(msg)-1] != '\n' {
buf = append(buf, '\n')
}
l.msg <- buf
}
// Fatal records the log with fatal level and exits
func (l *Logger) Fatal(args ...interface{}) {
l.Output(2, LevelFatal, fmt.Sprint(args...))
os.Exit(1)
}
// Fatalf records the log with fatal level and exits
func (l *Logger) Fatalf(format string, args ...interface{}) {
l.Output(2, LevelFatal, fmt.Sprintf(format, args...))
os.Exit(1)
}
// Fatalln records the log with fatal level and exits
func (l *Logger) Fatalln(args ...interface{}) {
l.Output(2, LevelFatal, fmt.Sprintln(args...))
os.Exit(1)
}
// Panic records the log with fatal level and panics
func (l *Logger) Panic(args ...interface{}) {
msg := fmt.Sprint(args...)
l.Output(2, LevelError, msg)
panic(msg)
}
// Panicf records the log with fatal level and panics
func (l *Logger) Panicf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
l.Output(2, LevelError, msg)
panic(msg)
}
// Panicln records the log with fatal level and panics
func (l *Logger) Panicln(args ...interface{}) {
msg := fmt.Sprintln(args...)
l.Output(2, LevelError, msg)
panic(msg)
}
// Print records the log with trace level
func (l *Logger) Print(args ...interface{}) {
l.Output(2, LevelTrace, fmt.Sprint(args...))
}
// Printf records the log with trace level
func (l *Logger) Printf(format string, args ...interface{}) {
l.Output(2, LevelTrace, fmt.Sprintf(format, args...))
}
// Println records the log with trace level
func (l *Logger) Println(args ...interface{}) {
l.Output(2, LevelTrace, fmt.Sprintln(args...))
}
// Debug records the log with debug level
func (l *Logger) Debug(args ...interface{}) {
l.Output(2, LevelDebug, fmt.Sprint(args...))
}
// Debugf records the log with debug level
func (l *Logger) Debugf(format string, args ...interface{}) {
l.Output(2, LevelDebug, fmt.Sprintf(format, args...))
}
// Debugln records the log with debug level
func (l *Logger) Debugln(args ...interface{}) {
l.Output(2, LevelDebug, fmt.Sprintln(args...))
}
// Error records the log with error level
func (l *Logger) Error(args ...interface{}) {
l.Output(2, LevelError, fmt.Sprint(args...))
}
// Errorf records the log with error level
func (l *Logger) Errorf(format string, args ...interface{}) {
l.Output(2, LevelError, fmt.Sprintf(format, args...))
}
// Errorln records the log with error level
func (l *Logger) Errorln(args ...interface{}) {
l.Output(2, LevelError, fmt.Sprintln(args...))
}
// Info records the log with info level
func (l *Logger) Info(args ...interface{}) {
l.Output(2, LevelInfo, fmt.Sprint(args...))
}
// Infof records the log with info level
func (l *Logger) Infof(format string, args ...interface{}) {
l.Output(2, LevelInfo, fmt.Sprintf(format, args...))
}
// Infoln records the log with info level
func (l *Logger) Infoln(args ...interface{}) {
l.Output(2, LevelInfo, fmt.Sprintln(args...))
}
// Warn records the log with warn level
func (l *Logger) Warn(args ...interface{}) {
l.Output(2, LevelWarn, fmt.Sprint(args...))
}
// Warnf records the log with warn level
func (l *Logger) Warnf(format string, args ...interface{}) {
l.Output(2, LevelWarn, fmt.Sprintf(format, args...))
}
// Warnln records the log with warn level
func (l *Logger) Warnln(args ...interface{}) {
l.Output(2, LevelWarn, fmt.Sprintln(args...))
}
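A sketch showing level filtering and swapping the global logger for a custom one; the handler choice is illustrative.

package main

import (
	"os"

	"github.com/siddontang/go-log/log"
)

func main() {
	h, _ := log.NewStreamHandler(os.Stdout)
	l := log.NewDefault(h)
	l.SetLevel(log.LevelWarn)

	l.Info("filtered out: below LevelWarn")
	l.Error("written")

	// Route the package-level helpers through the same logger.
	log.SetDefaultLogger(l)
	log.Errorf("also written via the global helpers: %v", os.ErrNotExist)
}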

View File

@ -0,0 +1,68 @@
// MIT License
// Copyright (c) 2017 Birkir A. Barkarson
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package loggers
// Standard is the interface used by Go's standard library's log package.
type Standard interface {
Fatal(args ...interface{})
Fatalf(format string, args ...interface{})
Fatalln(args ...interface{})
Panic(args ...interface{})
Panicf(format string, args ...interface{})
Panicln(args ...interface{})
Print(args ...interface{})
Printf(format string, args ...interface{})
Println(args ...interface{})
}
// Advanced is an interface with commonly used log level methods.
type Advanced interface {
Standard
Debug(args ...interface{})
Debugf(format string, args ...interface{})
Debugln(args ...interface{})
Error(args ...interface{})
Errorf(format string, args ...interface{})
Errorln(args ...interface{})
Info(args ...interface{})
Infof(format string, args ...interface{})
Infoln(args ...interface{})
Warn(args ...interface{})
Warnf(format string, args ...interface{})
Warnln(args ...interface{})
}
// Contextual is an interface that allows context addition to a log statement before
// calling the final print (message/level) method.
type Contextual interface {
Advanced
WithField(key string, value interface{}) Advanced
WithFields(fields ...interface{}) Advanced
}
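Because *log.Logger defines every method of loggers.Advanced, it can be passed wherever these interfaces are expected; a compile-time sketch using both vendored packages:

package main

import (
	"os"

	"github.com/siddontang/go-log/log"
	"github.com/siddontang/go-log/loggers"
)

// report depends only on the Advanced interface, not on a concrete logger type.
func report(l loggers.Advanced, host string) {
	l.Infof("probing %s", host)
	l.Debugln("probe finished")
}

func main() {
	h, _ := log.NewStreamHandler(os.Stdout)
	report(log.NewDefault(h), "db-1.internal") // *log.Logger satisfies loggers.Advanced
}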

View File

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2014 siddontang
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,25 @@
BSON library for Go
Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,27 @@
Copyright (c) 2011 The LevelDB-Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,27 @@
package hack
import (
"reflect"
"unsafe"
)
// String converts a byte slice to a string without copying.
// Use at your own risk: the string shares the slice's memory.
func String(b []byte) (s string) {
pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b))
pstring := (*reflect.StringHeader)(unsafe.Pointer(&s))
pstring.Data = pbytes.Data
pstring.Len = pbytes.Len
return
}
// Slice converts a string to a byte slice without copying.
// Use at your own risk: the returned slice must not be modified.
func Slice(s string) (b []byte) {
pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b))
pstring := (*reflect.StringHeader)(unsafe.Pointer(&s))
pbytes.Data = pstring.Data
pbytes.Len = pstring.Len
pbytes.Cap = pstring.Len
return
}
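A cautious usage sketch of the zero-copy helpers above; the result shares memory with its input, which is why the comments say to use them at your own risk.

package main

import (
	"fmt"

	"github.com/siddontang/go/hack"
)

func main() {
	b := []byte("hello")
	s := hack.String(b) // no allocation; s shares b's backing array
	fmt.Println(s)      // hello

	// Mutating b is visible through s, so never do this with strings used as map keys, etc.
	b[0] = 'H'
	fmt.Println(s) // Hello

	b2 := hack.Slice("world") // do NOT write to b2: string data is read-only
	fmt.Println(len(b2), cap(b2)) // 5 5
}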

Some files were not shown because too many files have changed in this diff