diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8301e70..c0c2464 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -11,7 +11,7 @@ Related issue: https://github.com/github/gh-ost/issues/0123456789 ### Description -This PR [briefly explain what is does] +This PR [briefly explain what it does] > In case this PR introduced Go code changes: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..cf518ee --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,25 @@ +name: CI + +on: [pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Go 1.15 + uses: actions/setup-go@v1 + with: + go-version: 1.15 + + - name: Build + run: script/cibuild + + - name: Upload gh-ost binary artifact + uses: actions/upload-artifact@v1 + with: + name: gh-ost + path: bin/gh-ost diff --git a/.github/workflows/replica-tests.yml b/.github/workflows/replica-tests.yml new file mode 100644 index 0000000..31e2052 --- /dev/null +++ b/.github/workflows/replica-tests.yml @@ -0,0 +1,19 @@ +name: migration tests + +on: [pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Go 1.14 + uses: actions/setup-go@v1 + with: + go-version: 1.14 + + - name: migration tests + run: script/cibuild-gh-ost-replica-tests diff --git a/.gitignore b/.gitignore index 63f0df9..605546d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ /bin/ /libexec/ /.vendor/ +.idea/ diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 079e425..0000000 --- a/.travis.yml +++ /dev/null @@ -1,20 +0,0 @@ -# http://docs.travis-ci.com/user/languages/go/ -language: go - -go: 1.8 - -os: - - linux - -env: -- MYSQL_USER=root - -before_install: - - mysql -e 'CREATE DATABASE IF NOT EXISTS test;' - -install: true - -script: script/cibuild - -notifications: - email: false diff --git a/Dockerfile.packaging b/Dockerfile.packaging new file mode 100644 index 0000000..092fade --- /dev/null +++ b/Dockerfile.packaging @@ -0,0 +1,22 @@ +# + +FROM golang:1.15.6 + +RUN apt-get update +RUN apt-get install -y ruby ruby-dev rubygems build-essential +RUN gem install --no-ri --no-rdoc fpm +ENV GOPATH=/tmp/go + +RUN apt-get install -y curl +RUN apt-get install -y rsync +RUN apt-get install -y gcc +RUN apt-get install -y g++ +RUN apt-get install -y bash +RUN apt-get install -y git +RUN apt-get install -y tar +RUN apt-get install -y rpm + +RUN mkdir -p $GOPATH/src/github.com/github/gh-ost +WORKDIR $GOPATH/src/github.com/github/gh-ost +COPY . . +RUN bash build.sh diff --git a/Dockerfile.test b/Dockerfile.test new file mode 100644 index 0000000..ceb46bf --- /dev/null +++ b/Dockerfile.test @@ -0,0 +1,11 @@ +FROM golang:1.15.6 +LABEL maintainer="github@github.com" + +RUN apt-get update +RUN apt-get install -y lsb-release +RUN rm -rf /var/lib/apt/lists/* + +COPY . /go/src/github.com/github/gh-ost +WORKDIR /go/src/github.com/github/gh-ost + +CMD ["script/test"] diff --git a/README.md b/README.md index 0d896bd..d496e08 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ Also see: - [the fine print](doc/the-fine-print.md) - [Community questions](https://github.com/github/gh-ost/issues?q=label%3Aquestion) - [Using `gh-ost` on AWS RDS](doc/rds.md) +- [Using `gh-ost` on Azure Database for MySQL](doc/azure.md) ## What's in a name? @@ -94,7 +95,7 @@ Please see [Coding gh-ost](doc/coding-ghost.md) for a guide to getting started d [Download latest release here](https://github.com/github/gh-ost/releases/latest) -`gh-ost` is a Go project; it is built with Go `1.8` (though `1.7` should work as well). To build on your own, use either: +`gh-ost` is a Go project; it is built with Go `1.14` and above. To build on your own, use either: - [script/build](https://github.com/github/gh-ost/blob/master/script/build) - this is the same build script used by CI hence the authoritative; artifact is `./bin/gh-ost` binary. - [build.sh](https://github.com/github/gh-ost/blob/master/build.sh) for building `tar.gz` artifacts in `/tmp/gh-ost` @@ -107,3 +108,6 @@ Generally speaking, `master` branch is stable, but only [releases](https://githu - [@ggunson](https://github.com/ggunson) - [@tomkrouper](https://github.com/tomkrouper) - [@shlomi-noach](https://github.com/shlomi-noach) +- [@jessbreckenridge](https://github.com/jessbreckenridge) +- [@gtowey](https://github.com/gtowey) +- [@timvaillancourt](https://github.com/timvaillancourt) diff --git a/RELEASE_VERSION b/RELEASE_VERSION index 4ad595c..9084fa2 100644 --- a/RELEASE_VERSION +++ b/RELEASE_VERSION @@ -1 +1 @@ -1.0.42 +1.1.0 diff --git a/build.sh b/build.sh index 3e1ce6f..b5d4659 100755 --- a/build.sh +++ b/build.sh @@ -2,36 +2,70 @@ # # -RELEASE_VERSION=$(cat RELEASE_VERSION) +RELEASE_VERSION= +buildpath= -function build { - osname=$1 - osshort=$2 - GOOS=$3 - GOARCH=$4 - - echo "Building ${osname} binary" - export GOOS - export GOARCH - go build -ldflags "$ldflags" -o $buildpath/$target go/cmd/gh-ost/main.go - - if [ $? -ne 0 ]; then - echo "Build failed for ${osname}" - exit 1 - fi - - (cd $buildpath && tar cfz ./gh-ost-binary-${osshort}-${timestamp}.tar.gz $target) +function setuptree() { + b=$( mktemp -d $buildpath/gh-ostXXXXXX ) || return 1 + mkdir -p $b/gh-ost + mkdir -p $b/gh-ost/usr/bin + echo $b } -buildpath=/tmp/gh-ost -target=gh-ost -timestamp=$(date "+%Y%m%d%H%M%S") -ldflags="-X main.AppVersion=${RELEASE_VERSION}" -export GO15VENDOREXPERIMENT=1 +function build { + osname=$1 + osshort=$2 + GOOS=$3 + GOARCH=$4 -mkdir -p ${buildpath} -build macOS osx darwin amd64 -build GNU/Linux linux linux amd64 + if ! go version | egrep -q 'go(1\.1[456])' ; then + echo "go version must be 1.14 or above" + exit 1 + fi -echo "Binaries found in:" -ls -1 $buildpath/gh-ost-binary*${timestamp}.tar.gz + echo "Building ${osname} binary" + export GOOS + export GOARCH + go build -ldflags "$ldflags" -o $buildpath/$target go/cmd/gh-ost/main.go + + if [ $? -ne 0 ]; then + echo "Build failed for ${osname}" + exit 1 + fi + + (cd $buildpath && tar cfz ./gh-ost-binary-${osshort}-${timestamp}.tar.gz $target) + + if [ "$GOOS" == "linux" ] ; then + echo "Creating Distro full packages" + builddir=$(setuptree) + cp $buildpath/$target $builddir/gh-ost/usr/bin + cd $buildpath + fpm -v "${RELEASE_VERSION}" --epoch 1 -f -s dir -n gh-ost -m 'shlomi-noach ' --description "GitHub's Online Schema Migrations for MySQL " --url "https://github.com/github/gh-ost" --vendor "GitHub" --license "Apache 2.0" -C $builddir/gh-ost --prefix=/ -t rpm . + fpm -v "${RELEASE_VERSION}" --epoch 1 -f -s dir -n gh-ost -m 'shlomi-noach ' --description "GitHub's Online Schema Migrations for MySQL " --url "https://github.com/github/gh-ost" --vendor "GitHub" --license "Apache 2.0" -C $builddir/gh-ost --prefix=/ -t deb --deb-no-default-config-files . + fi +} + +main() { + if [ -z "${RELEASE_VERSION}" ] ; then + RELEASE_VERSION=$(git describe --abbrev=0 --tags | tr -d 'v') + fi + if [ -z "${RELEASE_VERSION}" ] ; then + RELEASE_VERSION=$(cat RELEASE_VERSION) + fi + + + buildpath=/tmp/gh-ost-release + target=gh-ost + timestamp=$(date "+%Y%m%d%H%M%S") + ldflags="-X main.AppVersion=${RELEASE_VERSION}" + + mkdir -p ${buildpath} + rm -rf ${buildpath:?}/* + build GNU/Linux linux linux amd64 + # build macOS osx darwin amd64 + + echo "Binaries found in:" + find $buildpath/gh-ost* -type f -maxdepth 1 +} + +main "$@" diff --git a/doc/azure.md b/doc/azure.md new file mode 100644 index 0000000..f544f37 --- /dev/null +++ b/doc/azure.md @@ -0,0 +1,26 @@ +`gh-ost` has been updated to work with Azure Database for MySQL however due to GitHub does not use it, this documentation is community driven so if you find a bug please [open an issue][new_issue]! + +# Azure Database for MySQL + +## Limitations + +- `gh-ost` runs should be setup use [`--assume-rbr`][assume_rbr_docs] and use `binlog_row_image=FULL`. +- Azure Database for MySQL does not use same user name suffix for master and replica, so master host, user and password need to be pointed out. + +## Step +1. Change the replica server's `binlog_row_image` from `MINIMAL` to `FULL`. See [guide](https://docs.microsoft.com/en-us/azure/mysql/howto-server-parameters) on Azure document. +2. Use your `gh-ost` always with additional 5 parameter +```{bash} +gh-ost \ +--azure \ +--assume-master-host=master-server-dns-name \ +--master-user="master-user-name" \ +--master-password="master-password" \ +--assume-rbr \ +[-- other paramters you need] +``` + + +[new_issue]: https://github.com/github/gh-ost/issues/new +[assume_rbr_docs]: https://github.com/github/gh-ost/blob/master/doc/command-line-flags.md#assume-rbr +[migrate_test_on_replica_docs]: https://github.com/github/gh-ost/blob/master/doc/cheatsheet.md#c-migratetest-on-replica \ No newline at end of file diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index cea7333..22dccbd 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -2,6 +2,14 @@ A more in-depth discussion of various `gh-ost` command line flags: implementation, implication, use cases. +### aliyun-rds + +Add this flag when executing on Aliyun RDS. + +### azure + +Add this flag when executing on Azure Database for MySQL. + ### allow-master-master See [`--assume-master-host`](#assume-master-host). @@ -14,7 +22,7 @@ If, for some reason, you do not wish `gh-ost` to connect to a replica, you may c ### approve-renamed-columns -When your migration issues a column rename (`change column old_name new_name ...`) `gh-ost` analyzes the statement to try an associate the old column name with new column name. Otherwise the new structure may also look like some column was dropped and another was added. +When your migration issues a column rename (`change column old_name new_name ...`) `gh-ost` analyzes the statement to try and associate the old column name with new column name. Otherwise the new structure may also look like some column was dropped and another was added. `gh-ost` will print out what it thinks the _rename_ implied, but will not issue the migration unless you provide with `--approve-renamed-columns`. @@ -65,6 +73,10 @@ This is somewhat similar to a Nagios `n`-times test, where `n` in our case is al Optional. Default is `safe`. See more discussion in [`cut-over`](cut-over.md) +### cut-over-lock-timeout-seconds + +Default `3`. Max number of seconds to hold locks on tables while attempting to cut-over (retry attempted when lock exceeds timeout). + ### discard-foreign-keys **Danger**: this flag will _silently_ discard any foreign keys existing on your table. @@ -82,7 +94,7 @@ The `--dml-batch-size` flag controls the size of the batched write. Allowed valu Why is this behavior configurable? Different workloads have different characteristics. Some workloads have very large writes, such that aggregating even `50` writes into a transaction makes for a significant transaction size. On other workloads write rate is high such that one just can't allow for a hundred more syncs to disk per second. The default value of `10` is a modest compromise that should probably work very well for most workloads. Your mileage may vary. -Noteworthy is that setting `--dml-batch-size` to higher value _does not_ mean `gh-ost` blocks or waits on writes. The batch size is an upper limit on transaction size, not a minimal one. If `gh-ost` doesn't have "enough" events in the pipe, it does not wait on the binary log, it just writes what it already has. This conveniently suggests that if write load is light enough for `gh-ost` to only see a few events in the binary log at a given time, then it is also light neough for `gh-ost` to apply a fraction of the batch size. +Noteworthy is that setting `--dml-batch-size` to higher value _does not_ mean `gh-ost` blocks or waits on writes. The batch size is an upper limit on transaction size, not a minimal one. If `gh-ost` doesn't have "enough" events in the pipe, it does not wait on the binary log, it just writes what it already has. This conveniently suggests that if write load is light enough for `gh-ost` to only see a few events in the binary log at a given time, then it is also light enough for `gh-ost` to apply a fraction of the batch size. ### exact-rowcount @@ -103,6 +115,22 @@ While the ongoing estimated number of rows is still heuristic, it's almost exact Without this parameter, migration is a _noop_: testing table creation and validity of migration, but not touching data. +### force-named-cut-over + +If given, a `cut-over` command must name the migrated table, or else ignored. + +### force-named-panic + +If given, a `panic` command must name the migrated table, or else ignored. + +### force-table-names + +Table name prefix to be used on the temporary tables. + +### gcp + +Add this flag when executing on a 1st generation Google Cloud Platform (GCP). + ### heartbeat-interval-millis Default 100. See [`subsecond-lag`](subsecond-lag.md) for details. @@ -117,6 +145,10 @@ We think `gh-ost` should not take chances or make assumptions about the user's t See [`initially-drop-ghost-table`](#initially-drop-ghost-table) +### initially-drop-socket-file + +Default False. Should `gh-ost` forcibly delete an existing socket file. Be careful: this might drop the socket file of a running migration! + ### max-lag-millis On a replication topology, this is perhaps the most important migration throttling factor: the maximum lag allowed for migration to work. If lag exceeds this value, migration throttles. @@ -133,7 +165,7 @@ List of metrics and threshold values; topping the threshold of any will cause th ### migrate-on-replica -Typically `gh-ost` is used to migrate tables on a master. If you wish to only perform the migration in full on a replica, connect `gh-ost` to said replica and pass `--migrate-on-replica`. `gh-ost` will briefly connect to the master but other issue no changes on the master. Migration will be fully executed on the replica, while making sure to maintain a small replication lag. +Typically `gh-ost` is used to migrate tables on a master. If you wish to only perform the migration in full on a replica, connect `gh-ost` to said replica and pass `--migrate-on-replica`. `gh-ost` will briefly connect to the master but otherwise will make no changes on the master. Migration will be fully executed on the replica, while making sure to maintain a small replication lag. ### postpone-cut-over-flag-file @@ -151,16 +183,44 @@ See also: [`concurrent-migrations`](cheatsheet.md#concurrent-migrations) on the ### skip-foreign-key-checks -By default `gh-ost` verifies no foreign keys exist on the migrated table. On servers with large number of tables this check can take a long time. If you're absolutely certain no foreign keys exist (table does not referenece other table nor is referenced by other tables) and wish to save the check time, provide with `--skip-foreign-key-checks`. +By default `gh-ost` verifies no foreign keys exist on the migrated table. On servers with large number of tables this check can take a long time. If you're absolutely certain no foreign keys exist (table does not reference other table nor is referenced by other tables) and wish to save the check time, provide with `--skip-foreign-key-checks`. + +### skip-strict-mode + +By default `gh-ost` enforces STRICT_ALL_TABLES sql_mode as a safety measure. In some cases this changes the behaviour of other modes (namely ERROR_FOR_DIVISION_BY_ZERO, NO_ZERO_DATE, and NO_ZERO_IN_DATE) which may lead to errors during migration. Use `--skip-strict-mode` to explicitly tell `gh-ost` not to enforce this. **Danger** This may have some unexpected disastrous side effects. ### skip-renamed-columns See [`approve-renamed-columns`](#approve-renamed-columns) +### ssl + +By default `gh-ost` does not use ssl/tls connections to the database servers when performing migrations. This flag instructs `gh-ost` to use encrypted connections. If enabled, `gh-ost` will use the system's ca certificate pool for server certificate verification. If a different certificate is needed for server verification, see `--ssl-ca`. If you wish to skip server verification, but still use encrypted connections, use with `--ssl-allow-insecure`. + +### ssl-allow-insecure + +Allows `gh-ost` to connect to the MySQL servers using encrypted connections, but without verifying the validity of the certificate provided by the server during the connection. Requires `--ssl`. + +### ssl-ca + +`--ssl-ca=/path/to/ca-cert.pem`: ca certificate file (in PEM format) to use for server certificate verification. If specified, the default system ca cert pool will not be used for verification, only the ca cert provided here. Requires `--ssl`. + +### ssl-cert + +`--ssl-cert=/path/to/ssl-cert.crt`: SSL public key certificate file (in PEM format). + +### ssl-key + +`--ssl-key=/path/to/ssl-key.key`: SSL private key file (in PEM format). + ### test-on-replica Issue the migration on a replica; do not modify data on master. Useful for validating, testing and benchmarking. See [`testing-on-replica`](testing-on-replica.md) +### test-on-replica-skip-replica-stop + +Default `False`. When `--test-on-replica` is enabled, do not issue commands stop replication (requires `--test-on-replica`). + ### throttle-control-replicas Provide a command delimited list of replicas; `gh-ost` will throttle when any of the given replicas lag beyond [`--max-lag-millis`](#max-lag-millis). The list can be queried and updated dynamically via [interactive commands](interactive-commands.md) diff --git a/doc/hooks.md b/doc/hooks.md index 43ddb7e..4c49c85 100644 --- a/doc/hooks.md +++ b/doc/hooks.md @@ -65,10 +65,15 @@ The following variables are available on all hooks: - `GH_OST_ELAPSED_COPY_SECONDS` - row-copy time (excluding startup, row-count and postpone time) - `GH_OST_ESTIMATED_ROWS` - estimated total rows in table - `GH_OST_COPIED_ROWS` - number of rows copied by `gh-ost` +- `GH_OST_INSPECTED_LAG` - lag in seconds (floating point) of inspected server +- `GH_OST_PROGRESS` - progress pct ([0..100], floating point) of migration - `GH_OST_MIGRATED_HOST` - `GH_OST_INSPECTED_HOST` - `GH_OST_EXECUTING_HOST` - `GH_OST_HOOKS_HINT` - copy of `--hooks-hint` value +- `GH_OST_HOOKS_HINT_OWNER` - copy of `--hooks-hint-owner` value +- `GH_OST_HOOKS_HINT_TOKEN` - copy of `--hooks-hint-token` value +- `GH_OST_DRY_RUN` - whether or not the `gh-ost` run is a dry run The following variable are available on particular hooks: diff --git a/doc/interactive-commands.md b/doc/interactive-commands.md index 5390690..591aa49 100644 --- a/doc/interactive-commands.md +++ b/doc/interactive-commands.md @@ -43,7 +43,7 @@ Both interfaces may serve at the same time. Both respond to simple text command, ### Querying for data -For commands that accept an argumetn as value, pass `?` (question mark) to _get_ current value rather than _set_ a new one. +For commands that accept an argument as value, pass `?` (question mark) to _get_ current value rather than _set_ a new one. ### Examples diff --git a/doc/questions.md b/doc/questions.md index be6eab0..7585bc7 100644 --- a/doc/questions.md +++ b/doc/questions.md @@ -28,3 +28,9 @@ It is therefore unlikely that `gh-ost` will support this behavior. Yes. TL;DR if running all on same replica/master, make sure to provide `--replica-server-id`. [Read more](cheatsheet.md#concurrent-migrations) # Why + +### Why Is the "Connect to Replica" mode preferred? + +To avoid placing extra load on the master. `gh-ost` connects as a replication client. Each additional replica adds some load to the master. + +To monitor replication lag from a replica. This makes the replication lag throttle, `--max-lag-millis`, more representative of the lag experienced by other replicas following the master (perhaps N levels deep in a tree of replicas). diff --git a/doc/rds.md b/doc/rds.md index 889d480..da59abb 100644 --- a/doc/rds.md +++ b/doc/rds.md @@ -1,4 +1,4 @@ -`gh-ost` has been updated to work with Amazon RDS however due to GitHub not relying using AWS for databases, this documentation is community driven so if you find a bug please [open an issue][new_issue]! +`gh-ost` has been updated to work with Amazon RDS however due to GitHub not using AWS for databases, this documentation is community driven so if you find a bug please [open an issue][new_issue]! # Amazon RDS @@ -26,6 +26,14 @@ If you use `pt-table-checksum` as a part of your data integrity checks, you migh This tool requires binlog_format=STATEMENT, but the current binlog_format is set to ROW and an error occurred while attempting to change it. If running MySQL 5.1.29 or newer, setting binlog_format requires the SUPER privilege. You will need to manually set binlog_format to 'STATEMENT' before running this tool. ``` +#### Binlog filtering + +In Aurora, the [binlog filtering feature][aws_replication_docs_bin_log_filtering] is enabled by default. This becomes an issue when gh-ost tries to do the cut-over, because gh-ost waits for an entry in the binlog to proceed but this entry will never end up in the binlog because it gets filtered out by the binlog filtering feature. +You need to turn this feature off during the migration process. +Set the `aurora_enable_repl_bin_log_filtering` parameter to 0 in the Parameter Group for your cluster. +When the migration is done, set it back to 1 (default). + + #### Preflight checklist Before trying to run any `gh-ost` migrations you will want to confirm the following: @@ -35,6 +43,7 @@ Before trying to run any `gh-ost` migrations you will want to confirm the follow - [ ] Executing `SHOW SLAVE STATUS\G` on your replica cluster displays the correct master host, binlog position, etc. - [ ] Database backup retention is greater than 1 day to enable binlogs - [ ] You have setup [`hooks`][ghost_hooks] to issue RDS procedures for stopping and starting replication. (see [github/gh-ost#163][ghost_rds_issue_tracking] for examples) +- [ ] The parameter `aurora_enable_repl_bin_log_filtering` is set to 0 [new_issue]: https://github.com/github/gh-ost/issues/new [assume_rbr_docs]: https://github.com/github/gh-ost/blob/master/doc/command-line-flags.md#assume-rbr @@ -43,3 +52,4 @@ Before trying to run any `gh-ost` migrations you will want to confirm the follow [percona_toolkit_patch]: https://github.com/jacobbednarz/percona-toolkit/commit/0271ba6a094da446a5e5bb8d99b5c26f1777f2b9 [ghost_hooks]: https://github.com/github/gh-ost/blob/master/doc/hooks.md [ghost_rds_issue_tracking]: https://github.com/github/gh-ost/issues/163 +[aws_replication_docs_bin_log_filtering]: https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/AuroraMySQL.Replication.html#AuroraMySQL.Replication.Performance \ No newline at end of file diff --git a/doc/requirements-and-limitations.md b/doc/requirements-and-limitations.md index e12f3f0..e09ae4f 100644 --- a/doc/requirements-and-limitations.md +++ b/doc/requirements-and-limitations.md @@ -22,14 +22,10 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th ### Limitations -- Foreign keys not supported. They may be supported in the future, to some extent. +- Foreign key constraints are not supported. They may be supported in the future, to some extent. - Triggers are not supported. They may be supported in the future. -- MySQL 5.7 generated columns are not supported. They may be supported in the future. - -- MySQL 5.7 `POINT` column type is not supported. - - MySQL 5.7 `JSON` columns are supported but not as part of `PRIMARY KEY` - The two _before_ & _after_ tables must share a `PRIMARY KEY` or other `UNIQUE KEY`. This key will be used by `gh-ost` to iterate through the table rows when copying. [Read more](shared-key.md) @@ -42,13 +38,18 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - It is not allowed to migrate a table where another table exists with same name and different upper/lower case. - For example, you may not migrate `MyTable` if another table called `MYtable` exists in the same schema. -- Amazon RDS works, but has it's own [limitations](rds.md). -- Google Cloud SQL is currently not supported +- Amazon RDS works, but has its own [limitations](rds.md). +- Google Cloud SQL works, `--gcp` flag required. +- Aliyun RDS works, `--aliyun-rds` flag required. +- Azure Database for MySQL works, `--azure` flag required, and have detailed document about it. (azure.md) - Multisource is not supported when migrating via replica. It _should_ work (but never tested) when connecting directly to master (`--allow-on-master`) - Master-master setup is only supported in active-passive setup. Active-active (where table is being written to on both masters concurrently) is unsupported. It may be supported in the future. -- If you have en `enum` field as part of your migration key (typically the `PRIMARY KEY`), migration performance will be degraded and potentially bad. [Read more](https://github.com/github/gh-ost/pull/277#issuecomment-254811520) +- If you have an `enum` field as part of your migration key (typically the `PRIMARY KEY`), migration performance will be degraded and potentially bad. [Read more](https://github.com/github/gh-ost/pull/277#issuecomment-254811520) - Migrating a `FEDERATED` table is unsupported and is irrelevant to the problem `gh-ost` tackles. + +- [Encrypted binary logs](https://www.percona.com/blog/2018/03/08/binlog-encryption-percona-server-mysql/) are not supported. +- `ALTER TABLE ... RENAME TO some_other_name` is not supported (and you shouldn't use `gh-ost` for such a trivial operation). diff --git a/doc/shared-key.md b/doc/shared-key.md index 480ad97..c7f24cc 100644 --- a/doc/shared-key.md +++ b/doc/shared-key.md @@ -1,12 +1,12 @@ # Shared key -A requirement for a migration to run is that the two _before_ and _after_ tables have a shared unique key. This is to elaborate and illustrate on the matter. +gh-ost requires for every migration that both the _before_ and _after_ versions of the table share the same unique not-null key columns. This page illustrates this rule. ### Introduction -Consider a classic, simple migration. The table is any normal: +Consider a simple migration, with a normal table, -``` +```sql CREATE TABLE tbl ( id bigint unsigned not null auto_increment, data varchar(255), @@ -15,54 +15,72 @@ CREATE TABLE tbl ( ) ``` -And the migration is a simple `add column ts timestamp`. - -In such migration there is no change in indexes, and in particular no change to any unique key, and specifically no change to the `PRIMARY KEY`. To run this migration, `gh-ost` would iterate the `tbl` table using the primary key, copy rows from `tbl` to the _ghost_ table `_tbl_gho` by order of `id`, and then apply binlog events onto `_tbl_gho`. - -Applying the binlog events assumes the existence of a shared unique key. For example, an `UPDATE` statement in the binary log translate to a `REPLACE` statement which `gh-ost` applies to the _ghost_ table. Such statement expects to add or replace an existing row based on given row data. In particular, it would _replace_ an existing row if a unique key violation is met. - -So `gh-ost` correlates `tbl` and `_tbl_gho` rows using a unique key. In the above example that would be the `PRIMARY KEY`. - -### Rules - -There must be a shared set of not-null columns for which there is a unique constraint in both the original table and the migration (_ghost_) table. - -### Interpreting the rules - -The same columns must be covered by a unique key in both tables. This doesn't have to be the `PRIMARY KEY`. This doesn't have to be a key of the same name. - -Upon migration, `gh-ost` inspects both the original and _ghost_ table and attempts to find at least one such unique key (or rather, a set of columns) that is shared between the two. Typically this would just be the `PRIMARY KEY`, but sometimes you may change the `PRIMARY KEY` itself, in which case `gh-ost` will look for other options. - -`gh-ost` expects unique keys where no `NULL` values are found, i.e. all columns covered by the unique key are defined as `NOT NULL`. This is implicitly true for `PRIMARY KEY`s. If no such key can be found, `gh-ost` bails out. In the event there is no such key, but you happen to _know_ your columns have no `NULL` values even though they're `NULL`-able, you may take responsibility and pass the `--allow-nullable-unique-key`. The migration will run well as long as no `NULL` values are found in the unique key's columns. Any actual `NULL`s may corrupt the migration. - -### Examples: allowed and not allowed +and the migration `add column ts timestamp`. The _after_ table version would be: +```sql +CREATE TABLE tbl ( + id bigint unsigned not null auto_increment, + data varchar(255), + more_data int, + ts timestamp, + PRIMARY KEY(id) +) ``` + +(This is also the definition of the _ghost_ table, except that that table would be called `_tbl_gho`). + +In this migration, the _before_ and _after_ versions contain the same unique not-null key (the PRIMARY KEY). To run this migration, `gh-ost` would iterate through the `tbl` table using the primary key, copy rows from `tbl` to the _ghost_ table `_tbl_gho` in primary key order, while also applying the binlog event writes from `tble` onto `_tbl_gho`. + +The applying of the binlog events is what requires the shared unique key. For example, an `UPDATE` statement to `tbl` translates to a `REPLACE` statement which `gh-ost` applies to `_tbl_gho`. A `REPLACE` statement expects to insert or replace an existing row based on its row's values and the table's unique key constraints. In particular, if inserting that row would result in a unique key violation (e.g., a row with that primary key already exists), it would _replace_ that existing row with the new values. + +So `gh-ost` correlates `tbl` and `_tbl_gho` rows one to one using a unique key. In the above example that would be the `PRIMARY KEY`. + +### Interpreting the rule + +The _before_ and _after_ versions of the table share the same unique not-null key, but: +- the key doesn't have to be the PRIMARY KEY +- the key can have a different name between the _before_ and _after_ versions (e.g., renamed via DROP INDEX and ADD INDEX) so long as it contains the exact same column(s) + +At the start of the migration, `gh-ost` inspects both the original and _ghost_ table it created, and attempts to find at least one such unique key (or rather, a set of columns) that is shared between the two. Typically this would just be the `PRIMARY KEY`, but some tables don't have primary keys, or sometimes it is the primary key that is being modified by the migration. In these cases `gh-ost` will look for other options. + +`gh-ost` expects unique keys where no `NULL` values are found, i.e. all columns contained in the unique key are defined as `NOT NULL`. This is implicitly true for primary keys. If no such key can be found, `gh-ost` bails out. + +If the table contains a unique key with nullable columns, but you know your columns contain no `NULL` values, use the `--allow-nullable-unique-key` option. The migration will run well as long as no `NULL` values are found in the unique key's columns. **Any actual `NULL`s may corrupt the migration.** + +### Examples: Allowed and Not Allowed + +```sql create table some_table ( - id int auto_increment, + id int not null auto_increment, ts timestamp, name varchar(128) not null, owner_id int not null, - loc_id int, + loc_id int not null, primary key(id), unique key name_uidx(name) ) ``` -Following are examples of migrations that are _good to run_: +Note the two unique, not-null indexes: the primary key and `name_uidx`. + +Allowed migrations: - `add column i int` -- `add key owner_idx(owner_id)` -- `add unique key owner_name_idx(owner_id, name)` - though you need to make sure to not write conflicting rows while this migration runs +- `add key owner_idx (owner_id)` +- `add unique key owner_name_idx (owner_id, name)` - **be careful not to write conflicting rows while this migration runs** - `drop key name_uidx` - `primary key` is shared between the tables -- `drop primary key, add primary key(owner_id, loc_id)` - `name_uidx` is shared between the tables and is used for migration -- `change id bigint unsigned` - the `'primary key` is used. The change of type still makes the `primary key` workable. -- `drop primary key, drop key name_uidx, create primary key(name), create unique key id_uidx(id)` - swapping the two keys. `gh-ost` is still happy because `id` is still unique in both tables. So is `name`. +- `drop primary key, add primary key(owner_id, loc_id)` - `name_uidx` is shared between the tables +- `change id bigint unsigned not null auto_increment` - the `primary key` changes datatype but not value, and can be used +- `drop primary key, drop key name_uidx, add primary key(name), add unique key id_uidx(id)` - swapping the two keys. Either `id` or `name` could be used + +Not allowed: + +- `drop primary key, drop key name_uidx` - the _ghost_ table has no unique key +- `drop primary key, drop key name_uidx, create primary key(name, owner_id)` - no shared columns to the unique keys on both tables. Even though `name` exists in the _ghost_ table's `primary key`, it is only part of the key and in itself does not guarantee uniqueness in the _ghost_ table. -Following are examples of migrations that _cannot run_: +### Workarounds -- `drop primary key, drop key name_uidx` - no unique key to _ghost_ table, so clearly cannot run -- `drop primary key, drop key name_uidx, create primary key(name, owner_id)` - no shared columns to both tables. Even though `name` exists in the _ghost_ table's `primary key`, it is only part of the key and in itself does not guarantee uniqueness in the _ghost_ table. - -Also, you cannot run a migration on a table that doesn't have some form of `unique key` in the first place, such as `some_table (id int, ts timestamp)` +If you need to change your primary key or only not-null unique index to use different columns, you will want to do it as two separate migrations: +1. `ADD UNIQUE KEY temp_pk (temp_pk_column,...)` +1. `DROP PRIMARY KEY, DROP KEY temp_pk, ADD PRIMARY KEY (temp_pk_column,...)` diff --git a/doc/throttle.md b/doc/throttle.md index bc4a315..2ebc2ee 100644 --- a/doc/throttle.md +++ b/doc/throttle.md @@ -46,6 +46,14 @@ Note that you may dynamically change both `--max-lag-millis` and the `throttle-c An example query could be: `--throttle-query="select hour(now()) between 8 and 17"` which implies throttling auto-starts `8:00am` and migration auto-resumes at `18:00pm`. +#### HTTP Throttle + +The `--throttle-http` flag allows for throttling via HTTP. Every 100ms `gh-ost` issues a `HEAD` request to the provided URL. If the response status code is not `200` throttling will kick in until a `200` response status code is returned. + +If no URL is provided or the URL provided doesn't contain the scheme then the HTTP check will be disabled. For example `--throttle-http="http://1.2.3.4:6789/throttle"` will enable the HTTP check/throttling, but `--throttle-http="1.2.3.4:6789/throttle"` will not. + +The URL can be queried and updated dynamically via [interactive interface](interactive-commands.md). + #### Manual control In addition to the above, you are able to take control and throttle the operation any time you like. diff --git a/doc/understanding-output.md b/doc/understanding-output.md index 8e139d8..6eff100 100644 --- a/doc/understanding-output.md +++ b/doc/understanding-output.md @@ -24,15 +24,15 @@ Initial output lines may look like this: 2016-05-19 17:57:11 INFO connection validated on 127.0.0.1:3306 2016-05-19 17:57:11 INFO rotate to next log name: mysql-bin.002587 2016-05-19 17:57:11 INFO connection validated on 127.0.0.1:3306 -2016-05-19 17:57:11 INFO Droppping table `mydb`.`_mytable_gst` +2016-05-19 17:57:11 INFO Dropping table `mydb`.`_mytable_gst` 2016-05-19 17:57:11 INFO Table dropped -2016-05-19 17:57:11 INFO Droppping table `mydb`.`_mytable_old` +2016-05-19 17:57:11 INFO Dropping table `mydb`.`_mytable_old` 2016-05-19 17:57:11 INFO Table dropped 2016-05-19 17:57:11 INFO Creating ghost table `mydb`.`_mytable_gst` 2016-05-19 17:57:11 INFO Ghost table created 2016-05-19 17:57:11 INFO Altering ghost table `mydb`.`_mytable_gst` 2016-05-19 17:57:11 INFO Ghost table altered -2016-05-19 17:57:11 INFO Droppping table `mydb`.`_mytable_osc` +2016-05-19 17:57:11 INFO Dropping table `mydb`.`_mytable_osc` 2016-05-19 17:57:11 INFO Table dropped 2016-05-19 17:57:11 INFO Creating changelog table `mydb`.`_mytable_osc` 2016-05-19 17:57:11 INFO Changelog table created diff --git a/doc/why-triggerless.md b/doc/why-triggerless.md index 50153b9..2ea8c81 100644 --- a/doc/why-triggerless.md +++ b/doc/why-triggerless.md @@ -16,7 +16,7 @@ Use of triggers simplifies a lot of the flow in doing a live table migration, bu Triggers are stored routines which are invoked on a per-row operation upon `INSERT`, `DELETE`, `UPDATE` on a table. They were introduced in MySQL `5.0`. -A trigger may contain a set of queries, and these queries run in the same transaction space as the query that manipulates the table. This makes for an atomicy of both the original operation on the table and the trigger-invoked operations. +A trigger may contain a set of queries, and these queries run in the same transaction space as the query that manipulates the table. This makes for an atomicity of both the original operation on the table and the trigger-invoked operations. ### Triggers, overhead diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..8a2bd2e --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,7 @@ +version: "3.5" +services: + app: + image: app + build: + context: . + dockerfile: Dockerfile.test diff --git a/go/base/context.go b/go/base/context.go index 71bb3d0..3211067 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -7,6 +7,7 @@ package base import ( "fmt" + "math" "os" "regexp" "strings" @@ -14,8 +15,11 @@ import ( "sync/atomic" "time" + "github.com/satori/go.uuid" + "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" + "github.com/outbrain/golib/log" "gopkg.in/gcfg.v1" gcfgscanner "gopkg.in/gcfg.v1/scanner" @@ -26,23 +30,23 @@ type RowsEstimateMethod string const ( TableStatusRowsEstimate RowsEstimateMethod = "TableStatusRowsEstimate" - ExplainRowsEstimate = "ExplainRowsEstimate" - CountRowsEstimate = "CountRowsEstimate" + ExplainRowsEstimate RowsEstimateMethod = "ExplainRowsEstimate" + CountRowsEstimate RowsEstimateMethod = "CountRowsEstimate" ) type CutOver int const ( - CutOverAtomic CutOver = iota - CutOverTwoStep = iota + CutOverAtomic CutOver = iota + CutOverTwoStep ) type ThrottleReasonHint string const ( NoThrottleReasonHint ThrottleReasonHint = "NoThrottleReasonHint" - UserCommandThrottleReasonHint = "UserCommandThrottleReasonHint" - LeavingHibernationThrottleReasonHint = "LeavingHibernationThrottleReasonHint" + UserCommandThrottleReasonHint ThrottleReasonHint = "UserCommandThrottleReasonHint" + LeavingHibernationThrottleReasonHint ThrottleReasonHint = "LeavingHibernationThrottleReasonHint" ) const ( @@ -71,9 +75,12 @@ func NewThrottleCheckResult(throttle bool, reason string, reasonHint ThrottleRea // MigrationContext has the general, global state of migration. It is used by // all components throughout the migration process. type MigrationContext struct { - DatabaseName string - OriginalTableName string - AlterStatement string + Uuid string + + DatabaseName string + OriginalTableName string + AlterStatement string + AlterStatementOptions string // anything following the 'ALTER TABLE [schema.]table' from AlterStatement CountTableRows bool ConcurrentCountTableRows bool @@ -82,17 +89,26 @@ type MigrationContext struct { SwitchToRowBinlogFormat bool AssumeRBR bool SkipForeignKeyChecks bool + SkipStrictMode bool NullableUniqueKeyAllowed bool ApproveRenamedColumns bool SkipRenamedColumns bool IsTungsten bool DiscardForeignKeys bool + AliyunRDS bool + GoogleCloudPlatform bool + AzureMySQL bool config ContextConfig configMutex *sync.Mutex ConfigFile string CliUser string CliPassword string + UseTLS bool + TLSAllowInsecure bool + TLSCACertificate string + TLSCertificate string + TLSKey string CliMasterUser string CliMasterPassword string @@ -106,6 +122,7 @@ type MigrationContext struct { ThrottleAdditionalFlagFile string throttleQuery string throttleHTTP string + IgnoreHTTPErrors bool ThrottleCommandedByUser int64 HibernateUntil int64 maxLoad LoadMap @@ -114,10 +131,15 @@ type MigrationContext struct { CriticalLoadHibernateSeconds int64 PostponeCutOverFlagFile string CutOverLockTimeoutSeconds int64 + CutOverExponentialBackoff bool + ExponentialBackoffMaxInterval int64 ForceNamedCutOverCommand bool + ForceNamedPanicCommand bool PanicFlagFile string HooksPath string HooksHintMessage string + HooksHintOwner string + HooksHintToken string DropServeSocket bool ServeSocketFile string @@ -157,6 +179,7 @@ type MigrationContext struct { pointOfInterestTime time.Time pointOfInterestTimeMutex *sync.Mutex CurrentLag int64 + currentProgress uint64 ThrottleHTTPStatusCode int64 controlReplicasLagResult mysql.ReplicationLagResult TotalRowsCopied int64 @@ -179,8 +202,10 @@ type MigrationContext struct { OriginalTableColumnsOnApplier *sql.ColumnList OriginalTableColumns *sql.ColumnList + OriginalTableVirtualColumns *sql.ColumnList OriginalTableUniqueKeys [](*sql.UniqueKey) GhostTableColumns *sql.ColumnList + GhostTableVirtualColumns *sql.ColumnList GhostTableUniqueKeys [](*sql.UniqueKey) UniqueKey *sql.UniqueKey SharedColumns *sql.ColumnList @@ -196,7 +221,24 @@ type MigrationContext struct { recentBinlogCoordinates mysql.BinlogCoordinates - CanStopStreaming func() bool + Log Logger +} + +type Logger interface { + Debug(args ...interface{}) + Debugf(format string, args ...interface{}) + Info(args ...interface{}) + Infof(format string, args ...interface{}) + Warning(args ...interface{}) error + Warningf(format string, args ...interface{}) error + Error(args ...interface{}) error + Errorf(format string, args ...interface{}) error + Errore(err error) error + Fatal(args ...interface{}) error + Fatalf(format string, args ...interface{}) error + Fatale(err error) error + SetLevel(level log.LogLevel) + SetPrintStackTrace(printStackTraceFlag bool) } type ContextConfig struct { @@ -212,14 +254,9 @@ type ContextConfig struct { } } -var context *MigrationContext - -func init() { - context = newMigrationContext() -} - -func newMigrationContext() *MigrationContext { +func NewMigrationContext() *MigrationContext { return &MigrationContext{ + Uuid: uuid.NewV4().String(), defaultNumRetries: 60, ChunkSize: 1000, InspectorConnectionConfig: mysql.NewConnectionConfig(), @@ -236,14 +273,10 @@ func newMigrationContext() *MigrationContext { pointOfInterestTimeMutex: &sync.Mutex{}, ColumnRenameMap: make(map[string]string), PanicAbort: make(chan error), + Log: NewDefaultLogger(), } } -// GetMigrationContext -func GetMigrationContext() *MigrationContext { - return context -} - func getSafeTableName(baseName string, suffix string) string { name := fmt.Sprintf("_%s_%s", baseName, suffix) if len(name) <= mysql.MaxTableNameLength { @@ -349,6 +382,14 @@ func (this *MigrationContext) SetCutOverLockTimeoutSeconds(timeoutSeconds int64) return nil } +func (this *MigrationContext) SetExponentialBackoffMaxInterval(intervalSeconds int64) error { + if intervalSeconds < 2 { + return fmt.Errorf("Minimal maximum interval is 2sec. Timeout remains at %d", this.ExponentialBackoffMaxInterval) + } + this.ExponentialBackoffMaxInterval = intervalSeconds + return nil +} + func (this *MigrationContext) SetDefaultNumRetries(retries int64) { this.throttleMutex.Lock() defer this.throttleMutex.Unlock() @@ -413,6 +454,20 @@ func (this *MigrationContext) MarkRowCopyEndTime() { this.RowCopyEndTime = time.Now() } +func (this *MigrationContext) GetCurrentLagDuration() time.Duration { + return time.Duration(atomic.LoadInt64(&this.CurrentLag)) +} + +func (this *MigrationContext) GetProgressPct() float64 { + return math.Float64frombits(atomic.LoadUint64(&this.currentProgress)) +} + +func (this *MigrationContext) SetProgressPct(progressPct float64) { + atomic.StoreUint64(&this.currentProgress, math.Float64bits(progressPct)) +} + +// math.Float64bits([f=0..100]) + // GetTotalRowsCopied returns the accurate number of rows being copied (affected) // This is not exactly the same as the rows being iterated via chunks, but potentially close enough func (this *MigrationContext) GetTotalRowsCopied() int64 { @@ -543,6 +598,13 @@ func (this *MigrationContext) SetThrottleHTTP(throttleHTTP string) { this.throttleHTTP = throttleHTTP } +func (this *MigrationContext) SetIgnoreHTTPErrors(ignoreHTTPErrors bool) { + this.throttleHTTPMutex.Lock() + defer this.throttleHTTPMutex.Unlock() + + this.IgnoreHTTPErrors = ignoreHTTPErrors +} + func (this *MigrationContext) GetMaxLoad() LoadMap { this.throttleMutex.Lock() defer this.throttleMutex.Unlock() @@ -689,6 +751,13 @@ func (this *MigrationContext) ApplyCredentials() { } } +func (this *MigrationContext) SetupTLS() error { + if this.UseTLS { + return this.InspectorConnectionConfig.UseTLS(this.TLSCACertificate, this.TLSCertificate, this.TLSKey, this.TLSAllowInsecure) + } + return nil +} + // ReadConfigFile attempts to read the config file, if it exists func (this *MigrationContext) ReadConfigFile() error { this.configMutex.Lock() diff --git a/go/base/context_test.go b/go/base/context_test.go index b3a98eb..8a9c6a5 100644 --- a/go/base/context_test.go +++ b/go/base/context_test.go @@ -19,27 +19,27 @@ func init() { func TestGetTableNames(t *testing.T) { { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "some_table" test.S(t).ExpectEquals(context.GetOldTableName(), "_some_table_del") test.S(t).ExpectEquals(context.GetGhostTableName(), "_some_table_gho") test.S(t).ExpectEquals(context.GetChangelogTableName(), "_some_table_ghc") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "a123456789012345678901234567890123456789012345678901234567890" test.S(t).ExpectEquals(context.GetOldTableName(), "_a1234567890123456789012345678901234567890123456789012345678_del") test.S(t).ExpectEquals(context.GetGhostTableName(), "_a1234567890123456789012345678901234567890123456789012345678_gho") test.S(t).ExpectEquals(context.GetChangelogTableName(), "_a1234567890123456789012345678901234567890123456789012345678_ghc") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "a123456789012345678901234567890123456789012345678901234567890123" oldTableName := context.GetOldTableName() test.S(t).ExpectEquals(oldTableName, "_a1234567890123456789012345678901234567890123456789012345678_del") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "a123456789012345678901234567890123456789012345678901234567890123" context.TimestampOldTable = true longForm := "Jan 2, 2006 at 3:04pm (MST)" @@ -48,7 +48,7 @@ func TestGetTableNames(t *testing.T) { test.S(t).ExpectEquals(oldTableName, "_a1234567890123456789012345678901234567890123_20130203195400_del") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "foo_bar_baz" context.ForceTmpTableName = "tmp" test.S(t).ExpectEquals(context.GetOldTableName(), "_tmp_del") diff --git a/go/base/default_logger.go b/go/base/default_logger.go new file mode 100644 index 0000000..be6b1f2 --- /dev/null +++ b/go/base/default_logger.go @@ -0,0 +1,73 @@ +package base + +import ( + "github.com/outbrain/golib/log" +) + +type simpleLogger struct{} + +func NewDefaultLogger() *simpleLogger { + return &simpleLogger{} +} + +func (*simpleLogger) Debug(args ...interface{}) { + log.Debug(args[0].(string), args[1:]) + return +} + +func (*simpleLogger) Debugf(format string, args ...interface{}) { + log.Debugf(format, args...) + return +} + +func (*simpleLogger) Info(args ...interface{}) { + log.Info(args[0].(string), args[1:]) + return +} + +func (*simpleLogger) Infof(format string, args ...interface{}) { + log.Infof(format, args...) + return +} + +func (*simpleLogger) Warning(args ...interface{}) error { + return log.Warning(args[0].(string), args[1:]) +} + +func (*simpleLogger) Warningf(format string, args ...interface{}) error { + return log.Warningf(format, args...) +} + +func (*simpleLogger) Error(args ...interface{}) error { + return log.Error(args[0].(string), args[1:]) +} + +func (*simpleLogger) Errorf(format string, args ...interface{}) error { + return log.Errorf(format, args...) +} + +func (*simpleLogger) Errore(err error) error { + return log.Errore(err) +} + +func (*simpleLogger) Fatal(args ...interface{}) error { + return log.Fatal(args[0].(string), args[1:]) +} + +func (*simpleLogger) Fatalf(format string, args ...interface{}) error { + return log.Fatalf(format, args...) +} + +func (*simpleLogger) Fatale(err error) error { + return log.Fatale(err) +} + +func (*simpleLogger) SetLevel(level log.LogLevel) { + log.SetLevel(level) + return +} + +func (*simpleLogger) SetPrintStackTrace(printStackTraceFlag bool) { + log.SetPrintStackTrace(printStackTraceFlag) + return +} diff --git a/go/base/utils.go b/go/base/utils.go index 9c47407..ed14514 100644 --- a/go/base/utils.go +++ b/go/base/utils.go @@ -15,7 +15,6 @@ import ( gosql "database/sql" "github.com/github/gh-ost/go/mysql" - "github.com/outbrain/golib/log" ) var ( @@ -41,10 +40,9 @@ func FileExists(fileName string) bool { func TouchFile(fileName string) error { f, err := os.OpenFile(fileName, os.O_APPEND|os.O_CREATE, 0755) if err != nil { - return (err) + return err } - defer f.Close() - return nil + return f.Close() } // StringContainsAll returns true if `s` contains all non empty given `substrings` @@ -65,20 +63,31 @@ func StringContainsAll(s string, substrings ...string) bool { return nonEmptyStringsFound } -func ValidateConnection(db *gosql.DB, connectionConfig *mysql.ConnectionConfig, name string) (string, error) { - query := `select @@global.port, @@global.version` +func ValidateConnection(db *gosql.DB, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) (string, error) { + versionQuery := `select @@global.version` var port, extraPort int var version string - if err := db.QueryRow(query).Scan(&port, &version); err != nil { + if err := db.QueryRow(versionQuery).Scan(&version); err != nil { return "", err } extraPortQuery := `select @@global.extra_port` if err := db.QueryRow(extraPortQuery).Scan(&extraPort); err != nil { // swallow this error. not all servers support extra_port } + // AliyunRDS set users port to "NULL", replace it by gh-ost param + // GCP set users port to "NULL", replace it by gh-ost param + // Azure MySQL set users port to a different value by design, replace it by gh-ost para + if migrationContext.AliyunRDS || migrationContext.GoogleCloudPlatform || migrationContext.AzureMySQL { + port = connectionConfig.Key.Port + } else { + portQuery := `select @@global.port` + if err := db.QueryRow(portQuery).Scan(&port); err != nil { + return "", err + } + } if connectionConfig.Key.Port == port || (extraPort > 0 && connectionConfig.Key.Port == extraPort) { - log.Infof("%s connection validated on %+v", name, connectionConfig.Key) + migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key) return version, nil } else if extraPort == 0 { return "", fmt.Errorf("Unexpected database port reported: %+v", port) diff --git a/go/binlog/binlog_dml_event.go b/go/binlog/binlog_dml_event.go index 4fab87a..2c7aa36 100644 --- a/go/binlog/binlog_dml_event.go +++ b/go/binlog/binlog_dml_event.go @@ -7,17 +7,18 @@ package binlog import ( "fmt" - "github.com/github/gh-ost/go/sql" "strings" + + "github.com/github/gh-ost/go/sql" ) type EventDML string const ( NotDML EventDML = "NoDML" - InsertDML = "Insert" - UpdateDML = "Update" - DeleteDML = "Delete" + InsertDML EventDML = "Insert" + UpdateDML EventDML = "Update" + DeleteDML EventDML = "Delete" ) func ToEventDML(description string) EventDML { diff --git a/go/binlog/binlog_entry.go b/go/binlog/binlog_entry.go index bb70bc5..5650acc 100644 --- a/go/binlog/binlog_entry.go +++ b/go/binlog/binlog_entry.go @@ -26,7 +26,7 @@ func NewBinlogEntry(logFile string, logPos uint64) *BinlogEntry { return binlogEntry } -// NewBinlogEntry creates an empty, ready to go BinlogEntry object +// NewBinlogEntryAt creates an empty, ready to go BinlogEntry object func NewBinlogEntryAt(coordinates mysql.BinlogCoordinates) *BinlogEntry { binlogEntry := &BinlogEntry{ Coordinates: coordinates, @@ -41,7 +41,7 @@ func (this *BinlogEntry) Duplicate() *BinlogEntry { return binlogEntry } -// Duplicate creates and returns a new binlog entry, with some of the attributes pre-assigned +// String() returns a string representation of this binlog entry func (this *BinlogEntry) String() string { return fmt.Sprintf("[BinlogEntry at %+v; dml:%+v]", this.Coordinates, this.DmlEvent) } diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index 9feca87..bc80cb5 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -13,41 +13,42 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - "github.com/outbrain/golib/log" gomysql "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/replication" "golang.org/x/net/context" ) type GoMySQLReader struct { + migrationContext *base.MigrationContext connectionConfig *mysql.ConnectionConfig binlogSyncer *replication.BinlogSyncer binlogStreamer *replication.BinlogStreamer currentCoordinates mysql.BinlogCoordinates currentCoordinatesMutex *sync.Mutex LastAppliedRowsEventHint mysql.BinlogCoordinates - MigrationContext *base.MigrationContext } -func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *GoMySQLReader, err error) { +func NewGoMySQLReader(migrationContext *base.MigrationContext) (binlogReader *GoMySQLReader, err error) { binlogReader = &GoMySQLReader{ - connectionConfig: connectionConfig, + migrationContext: migrationContext, + connectionConfig: migrationContext.InspectorConnectionConfig, currentCoordinates: mysql.BinlogCoordinates{}, currentCoordinatesMutex: &sync.Mutex{}, binlogSyncer: nil, binlogStreamer: nil, - MigrationContext: base.GetMigrationContext(), } - serverId := uint32(binlogReader.MigrationContext.ReplicaServerId) + serverId := uint32(migrationContext.ReplicaServerId) - binlogSyncerConfig := &replication.BinlogSyncerConfig{ - ServerID: serverId, - Flavor: "mysql", - Host: connectionConfig.Key.Hostname, - Port: uint16(connectionConfig.Key.Port), - User: connectionConfig.User, - Password: connectionConfig.Password, + binlogSyncerConfig := replication.BinlogSyncerConfig{ + ServerID: serverId, + Flavor: "mysql", + Host: binlogReader.connectionConfig.Key.Hostname, + Port: uint16(binlogReader.connectionConfig.Key.Port), + User: binlogReader.connectionConfig.User, + Password: binlogReader.connectionConfig.Password, + TLSConfig: binlogReader.connectionConfig.TLSConfig(), + UseDecimal: true, } binlogReader.binlogSyncer = replication.NewBinlogSyncer(binlogSyncerConfig) @@ -57,12 +58,12 @@ func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *G // ConnectBinlogStreamer func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordinates) (err error) { if coordinates.IsEmpty() { - return log.Errorf("Emptry coordinates at ConnectBinlogStreamer()") + return this.migrationContext.Log.Errorf("Empty coordinates at ConnectBinlogStreamer()") } this.currentCoordinates = coordinates - log.Infof("Connecting binlog streamer at %+v", this.currentCoordinates) - // Start sync with sepcified binlog file and position + this.migrationContext.Log.Infof("Connecting binlog streamer at %+v", this.currentCoordinates) + // Start sync with specified binlog file and position this.binlogStreamer, err = this.binlogSyncer.StartSync(gomysql.Position{this.currentCoordinates.LogFile, uint32(this.currentCoordinates.LogPos)}) return err @@ -78,7 +79,7 @@ func (this *GoMySQLReader) GetCurrentBinlogCoordinates() *mysql.BinlogCoordinate // StreamEvents func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEvent *replication.RowsEvent, entriesChannel chan<- *BinlogEntry) error { if this.currentCoordinates.SmallerThanOrEquals(&this.LastAppliedRowsEventHint) { - log.Debugf("Skipping handled query at %+v", this.currentCoordinates) + this.migrationContext.Log.Debugf("Skipping handled query at %+v", this.currentCoordinates) return nil } @@ -113,8 +114,8 @@ func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEven binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row) } } - // The channel will do the throttling. Whoever is reding from the channel - // decides whether action is taken sycnhronously (meaning we wait before + // The channel will do the throttling. Whoever is reading from the channel + // decides whether action is taken synchronously (meaning we wait before // next iteration) or asynchronously (we keep pushing more events) // In reality, reads will be synchronous entriesChannel <- binlogEntry @@ -147,23 +148,19 @@ func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesCha defer this.currentCoordinatesMutex.Unlock() this.currentCoordinates.LogFile = string(rotateEvent.NextLogName) }() - log.Infof("rotate to next log name: %s", rotateEvent.NextLogName) + this.migrationContext.Log.Infof("rotate to next log from %s:%d to %s", this.currentCoordinates.LogFile, int64(ev.Header.LogPos), rotateEvent.NextLogName) } else if rowsEvent, ok := ev.Event.(*replication.RowsEvent); ok { if err := this.handleRowsEvent(ev, rowsEvent, entriesChannel); err != nil { return err } } } - log.Debugf("done streaming events") + this.migrationContext.Log.Debugf("done streaming events") return nil } func (this *GoMySQLReader) Close() error { - // Historically there was a: - // this.binlogSyncer.Close() - // here. A new go-mysql version closes the binlog syncer connection independently. - // I will go against the sacred rules of comments and just leave this here. - // This is the year 2017. Let's see what year these comments get deleted. + this.binlogSyncer.Close() return nil } diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 4509e4a..b8557f9 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -14,6 +14,8 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/logic" + "github.com/github/gh-ost/go/sql" + _ "github.com/go-sql-driver/mysql" "github.com/outbrain/golib/log" "golang.org/x/crypto/ssh/terminal" @@ -30,7 +32,7 @@ func acceptSignals(migrationContext *base.MigrationContext) { for sig := range c { switch sig { case syscall.SIGHUP: - log.Infof("Received SIGHUP. Reloading configuration") + migrationContext.Log.Infof("Received SIGHUP. Reloading configuration") if err := migrationContext.ReadConfigFile(); err != nil { log.Errore(err) } else { @@ -43,11 +45,11 @@ func acceptSignals(migrationContext *base.MigrationContext) { // main is the application's entry point. It will either spawn a CLI or HTTP interfaces. func main() { - migrationContext := base.GetMigrationContext() - + migrationContext := base.NewMigrationContext() flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)") - flag.StringVar(&migrationContext.AssumeMasterHostname, "assume-master-host", "", "(optional) explicitly tell gh-ost the identity of the master. Format: some.host.com[:port] This is useful in master-master setups where you wish to pick an explicit master, or in a tungsten-replicator where gh-ost is unabel to determine the master") + flag.StringVar(&migrationContext.AssumeMasterHostname, "assume-master-host", "", "(optional) explicitly tell gh-ost the identity of the master. Format: some.host.com[:port] This is useful in master-master setups where you wish to pick an explicit master, or in a tungsten-replicator where gh-ost is unable to determine the master") flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)") + flag.Float64Var(&migrationContext.InspectorConnectionConfig.Timeout, "mysql-timeout", 0.0, "Connect, read and write timeout for MySQL") flag.StringVar(&migrationContext.CliUser, "user", "", "MySQL user") flag.StringVar(&migrationContext.CliPassword, "password", "", "MySQL password") flag.StringVar(&migrationContext.CliMasterUser, "master-user", "", "MySQL user on master, if different from that on replica. Requires --assume-master-host") @@ -55,6 +57,12 @@ func main() { flag.StringVar(&migrationContext.ConfigFile, "conf", "", "Config file") askPass := flag.Bool("ask-pass", false, "prompt for MySQL password") + flag.BoolVar(&migrationContext.UseTLS, "ssl", false, "Enable SSL encrypted connections to MySQL hosts") + flag.StringVar(&migrationContext.TLSCACertificate, "ssl-ca", "", "CA certificate in PEM format for TLS connections to MySQL hosts. Requires --ssl") + flag.StringVar(&migrationContext.TLSCertificate, "ssl-cert", "", "Certificate in PEM format for TLS connections to MySQL hosts. Requires --ssl") + flag.StringVar(&migrationContext.TLSKey, "ssl-key", "", "Key in PEM format for TLS connections to MySQL hosts. Requires --ssl") + flag.BoolVar(&migrationContext.TLSAllowInsecure, "ssl-allow-insecure", false, "Skips verification of MySQL hosts' certificate chain and host name. Requires --ssl") + flag.StringVar(&migrationContext.DatabaseName, "database", "", "database name (mandatory)") flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)") flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory)") @@ -68,6 +76,10 @@ func main() { flag.BoolVar(&migrationContext.IsTungsten, "tungsten", false, "explicitly let gh-ost know that you are running on a tungsten-replication based topology (you are likely to also provide --assume-master-host)") flag.BoolVar(&migrationContext.DiscardForeignKeys, "discard-foreign-keys", false, "DANGER! This flag will migrate a table that has foreign keys and will NOT create foreign keys on the ghost table, thus your altered table will have NO foreign keys. This is useful for intentional dropping of foreign keys") flag.BoolVar(&migrationContext.SkipForeignKeyChecks, "skip-foreign-key-checks", false, "set to 'true' when you know for certain there are no foreign keys on your table, and wish to skip the time it takes for gh-ost to verify that") + flag.BoolVar(&migrationContext.SkipStrictMode, "skip-strict-mode", false, "explicitly tell gh-ost binlog applier not to enforce strict sql mode") + flag.BoolVar(&migrationContext.AliyunRDS, "aliyun-rds", false, "set to 'true' when you execute on Aliyun RDS.") + flag.BoolVar(&migrationContext.GoogleCloudPlatform, "gcp", false, "set to 'true' when you execute on a 1st generation Google Cloud Platform (GCP).") + flag.BoolVar(&migrationContext.AzureMySQL, "azure", false, "set to 'true' when you execute on Azure Database on MySQL.") executeFlag := flag.Bool("execute", false, "actually execute the alter & migrate the table. Default is noop: do some tests and exit") flag.BoolVar(&migrationContext.TestOnReplica, "test-on-replica", false, "Have the migration run on a replica, not on the master. At the end of migration replication is stopped, and tables are swapped and immediately swap-revert. Replication remains stopped and you can compare the two tables for building trust") @@ -80,9 +92,12 @@ func main() { flag.BoolVar(&migrationContext.TimestampOldTable, "timestamp-old-table", false, "Use a timestamp in old table name. This makes old table names unique and non conflicting cross migrations") cutOver := flag.String("cut-over", "atomic", "choose cut-over type (default|atomic, two-step)") flag.BoolVar(&migrationContext.ForceNamedCutOverCommand, "force-named-cut-over", false, "When true, the 'unpostpone|cut-over' interactive command must name the migrated table") + flag.BoolVar(&migrationContext.ForceNamedPanicCommand, "force-named-panic", false, "When true, the 'panic' interactive command must name the migrated table") flag.BoolVar(&migrationContext.SwitchToRowBinlogFormat, "switch-to-rbr", false, "let this tool automatically switch binary log format to 'ROW' on the replica, if needed. The format will NOT be switched back. I'm too scared to do that, and wish to protect you if you happen to execute another migration while this one is running") flag.BoolVar(&migrationContext.AssumeRBR, "assume-rbr", false, "set to 'true' when you know for certain your server uses 'ROW' binlog_format. gh-ost is unable to tell, event after reading binlog_format, whether the replication process does indeed use 'ROW', and restarts replication to be certain RBR setting is applied. Such operation requires SUPER privileges which you might not have. Setting this flag avoids restarting replication and you can proceed to use gh-ost without SUPER privileges") + flag.BoolVar(&migrationContext.CutOverExponentialBackoff, "cut-over-exponential-backoff", false, "Wait exponentially longer intervals between failed cut-over attempts. Wait intervals obey a maximum configurable with 'exponential-backoff-max-interval').") + exponentialBackoffMaxInterval := flag.Int64("exponential-backoff-max-interval", 64, "Maximum number of seconds to wait between attempts when performing various operations with exponential backoff.") chunkSize := flag.Int64("chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 100-100,000)") dmlBatchSize := flag.Int64("dml-batch-size", 10, "batch size for DML events to apply in a single transaction (range 1-100)") defaultRetries := flag.Int64("default-retries", 60, "Default number of retries for various operations before panicking") @@ -94,6 +109,7 @@ func main() { throttleControlReplicas := flag.String("throttle-control-replicas", "", "List of replicas on which to check for lag; comma delimited. Example: myhost1.com:3306,myhost2.com,myhost3.com:3307") throttleQuery := flag.String("throttle-query", "", "when given, issued (every second) to check if operation should throttle. Expecting to return zero for no-throttle, >0 for throttle. Query is issued on the migrated server. Make sure this query is lightweight") throttleHTTP := flag.String("throttle-http", "", "when given, gh-ost checks given URL via HEAD request; any response code other than 200 (OK) causes throttling; make sure it has low latency response") + ignoreHTTPErrors := flag.Bool("ignore-http-errors", false, "ignore HTTP connection errors during throttle check") heartbeatIntervalMillis := flag.Int64("heartbeat-interval-millis", 100, "how frequently would gh-ost inject a heartbeat value") flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered") flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-ost.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-ost operations") @@ -106,6 +122,8 @@ func main() { flag.StringVar(&migrationContext.HooksPath, "hooks-path", "", "directory where hook files are found (default: empty, ie. hooks disabled). Hook files found on this path, and conforming to hook naming conventions will be executed") flag.StringVar(&migrationContext.HooksHintMessage, "hooks-hint", "", "arbitrary message to be injected to hooks via GH_OST_HOOKS_HINT, for your convenience") + flag.StringVar(&migrationContext.HooksHintOwner, "hooks-hint-owner", "", "arbitrary name of owner to be injected to hooks via GH_OST_HOOKS_HINT_OWNER, for your convenience") + flag.StringVar(&migrationContext.HooksHintToken, "hooks-hint-token", "", "arbitrary token to be injected to hooks via GH_OST_HOOKS_HINT_TOKEN, for your convenience") flag.UintVar(&migrationContext.ReplicaServerId, "replica-server-id", 99999, "server id used by gh-ost process. Default: 99999") @@ -121,6 +139,7 @@ func main() { version := flag.Bool("version", false, "Print version & exit") checkFlag := flag.Bool("check-flag", false, "Check if another flag exists/supported. This allows for cross-version scripting. Exits with 0 when all additional provided flags exist, nonzero otherwise. You must provide (dummy) values for flags that require a value. Example: gh-ost --check-flag --cut-over-lock-timeout-seconds --nice-ratio 0") flag.StringVar(&migrationContext.ForceTmpTableName, "force-table-names", "", "table name prefix to be used on the temporary tables") + flag.CommandLine.SetOutput(os.Stdout) flag.Parse() @@ -128,7 +147,7 @@ func main() { return } if *help { - fmt.Fprintf(os.Stderr, "Usage of gh-ost:\n") + fmt.Fprintf(os.Stdout, "Usage of gh-ost:\n") flag.PrintDefaults() return } @@ -141,57 +160,80 @@ func main() { return } - log.SetLevel(log.ERROR) + migrationContext.Log.SetLevel(log.ERROR) if *verbose { - log.SetLevel(log.INFO) + migrationContext.Log.SetLevel(log.INFO) } if *debug { - log.SetLevel(log.DEBUG) + migrationContext.Log.SetLevel(log.DEBUG) } if *stack { - log.SetPrintStackTrace(*stack) + migrationContext.Log.SetPrintStackTrace(*stack) } if *quiet { // Override!! - log.SetLevel(log.ERROR) + migrationContext.Log.SetLevel(log.ERROR) } - if migrationContext.DatabaseName == "" { - log.Fatalf("--database must be provided and database name must not be empty") - } - if migrationContext.OriginalTableName == "" { - log.Fatalf("--table must be provided and table name must not be empty") - } if migrationContext.AlterStatement == "" { log.Fatalf("--alter must be provided and statement must not be empty") } + parser := sql.NewParserFromAlterStatement(migrationContext.AlterStatement) + migrationContext.AlterStatementOptions = parser.GetAlterStatementOptions() + + if migrationContext.DatabaseName == "" { + if parser.HasExplicitSchema() { + migrationContext.DatabaseName = parser.GetExplicitSchema() + } else { + log.Fatalf("--database must be provided and database name must not be empty, or --alter must specify database name") + } + } + if migrationContext.OriginalTableName == "" { + if parser.HasExplicitTable() { + migrationContext.OriginalTableName = parser.GetExplicitTable() + } else { + log.Fatalf("--table must be provided and table name must not be empty, or --alter must specify table name") + } + } migrationContext.Noop = !(*executeFlag) if migrationContext.AllowedRunningOnMaster && migrationContext.TestOnReplica { - log.Fatalf("--allow-on-master and --test-on-replica are mutually exclusive") + migrationContext.Log.Fatalf("--allow-on-master and --test-on-replica are mutually exclusive") } if migrationContext.AllowedRunningOnMaster && migrationContext.MigrateOnReplica { - log.Fatalf("--allow-on-master and --migrate-on-replica are mutually exclusive") + migrationContext.Log.Fatalf("--allow-on-master and --migrate-on-replica are mutually exclusive") } if migrationContext.MigrateOnReplica && migrationContext.TestOnReplica { - log.Fatalf("--migrate-on-replica and --test-on-replica are mutually exclusive") + migrationContext.Log.Fatalf("--migrate-on-replica and --test-on-replica are mutually exclusive") } if migrationContext.SwitchToRowBinlogFormat && migrationContext.AssumeRBR { - log.Fatalf("--switch-to-rbr and --assume-rbr are mutually exclusive") + migrationContext.Log.Fatalf("--switch-to-rbr and --assume-rbr are mutually exclusive") } if migrationContext.TestOnReplicaSkipReplicaStop { if !migrationContext.TestOnReplica { - log.Fatalf("--test-on-replica-skip-replica-stop requires --test-on-replica to be enabled") + migrationContext.Log.Fatalf("--test-on-replica-skip-replica-stop requires --test-on-replica to be enabled") } - log.Warning("--test-on-replica-skip-replica-stop enabled. We will not stop replication before cut-over. Ensure you have a plugin that does this.") + migrationContext.Log.Warning("--test-on-replica-skip-replica-stop enabled. We will not stop replication before cut-over. Ensure you have a plugin that does this.") } if migrationContext.CliMasterUser != "" && migrationContext.AssumeMasterHostname == "" { - log.Fatalf("--master-user requires --assume-master-host") + migrationContext.Log.Fatalf("--master-user requires --assume-master-host") } if migrationContext.CliMasterPassword != "" && migrationContext.AssumeMasterHostname == "" { - log.Fatalf("--master-password requires --assume-master-host") + migrationContext.Log.Fatalf("--master-password requires --assume-master-host") + } + if migrationContext.TLSCACertificate != "" && !migrationContext.UseTLS { + migrationContext.Log.Fatalf("--ssl-ca requires --ssl") + } + if migrationContext.TLSCertificate != "" && !migrationContext.UseTLS { + migrationContext.Log.Fatalf("--ssl-cert requires --ssl") + } + if migrationContext.TLSKey != "" && !migrationContext.UseTLS { + migrationContext.Log.Fatalf("--ssl-key requires --ssl") + } + if migrationContext.TLSAllowInsecure && !migrationContext.UseTLS { + migrationContext.Log.Fatalf("--ssl-allow-insecure requires --ssl") } if *replicationLagQuery != "" { - log.Warningf("--replication-lag-query is deprecated") + migrationContext.Log.Warningf("--replication-lag-query is deprecated") } switch *cutOver { @@ -200,19 +242,19 @@ func main() { case "two-step": migrationContext.CutOverType = base.CutOverTwoStep default: - log.Fatalf("Unknown cut-over: %s", *cutOver) + migrationContext.Log.Fatalf("Unknown cut-over: %s", *cutOver) } if err := migrationContext.ReadConfigFile(); err != nil { - log.Fatale(err) + migrationContext.Log.Fatale(err) } if err := migrationContext.ReadThrottleControlReplicaKeys(*throttleControlReplicas); err != nil { - log.Fatale(err) + migrationContext.Log.Fatale(err) } if err := migrationContext.ReadMaxLoad(*maxLoad); err != nil { - log.Fatale(err) + migrationContext.Log.Fatale(err) } if err := migrationContext.ReadCriticalLoad(*criticalLoad); err != nil { - log.Fatale(err) + migrationContext.Log.Fatale(err) } if migrationContext.ServeSocketFile == "" { migrationContext.ServeSocketFile = fmt.Sprintf("/tmp/gh-ost.%s.%s.sock", migrationContext.DatabaseName, migrationContext.OriginalTableName) @@ -221,7 +263,7 @@ func main() { fmt.Println("Password:") bytePassword, err := terminal.ReadPassword(int(syscall.Stdin)) if err != nil { - log.Fatale(err) + migrationContext.Log.Fatale(err) } migrationContext.CliPassword = string(bytePassword) } @@ -232,20 +274,27 @@ func main() { migrationContext.SetMaxLagMillisecondsThrottleThreshold(*maxLagMillis) migrationContext.SetThrottleQuery(*throttleQuery) migrationContext.SetThrottleHTTP(*throttleHTTP) + migrationContext.SetIgnoreHTTPErrors(*ignoreHTTPErrors) migrationContext.SetDefaultNumRetries(*defaultRetries) migrationContext.ApplyCredentials() + if err := migrationContext.SetupTLS(); err != nil { + migrationContext.Log.Fatale(err) + } if err := migrationContext.SetCutOverLockTimeoutSeconds(*cutOverLockTimeoutSeconds); err != nil { - log.Errore(err) + migrationContext.Log.Errore(err) + } + if err := migrationContext.SetExponentialBackoffMaxInterval(*exponentialBackoffMaxInterval); err != nil { + migrationContext.Log.Errore(err) } log.Infof("starting gh-ost %+v", AppVersion) acceptSignals(migrationContext) - migrator := logic.NewMigrator() + migrator := logic.NewMigrator(migrationContext) err := migrator.Migrate() if err != nil { migrator.ExecOnFailureHook() - log.Fatale(err) + migrationContext.Log.Fatale(err) } fmt.Fprintf(os.Stdout, "# Done\n") } diff --git a/go/logic/applier.go b/go/logic/applier.go index 9e645f4..fb4bc8d 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -16,65 +16,92 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - "github.com/outbrain/golib/log" "github.com/outbrain/golib/sqlutils" + "sync" ) const ( atomicCutOverMagicHint = "ghost-cut-over-sentry" ) +type dmlBuildResult struct { + query string + args []interface{} + rowsDelta int64 + err error +} + +func newDmlBuildResult(query string, args []interface{}, rowsDelta int64, err error) *dmlBuildResult { + return &dmlBuildResult{ + query: query, + args: args, + rowsDelta: rowsDelta, + err: err, + } +} + +func newDmlBuildResultError(err error) *dmlBuildResult { + return &dmlBuildResult{ + err: err, + } +} + // Applier connects and writes the the applier-server, which is the server where migration // happens. This is typically the master, but could be a replica when `--test-on-replica` or // `--execute-on-replica` are given. // Applier is the one to actually write row data and apply binlog events onto the ghost table. // It is where the ghost & changelog tables get created. It is where the cut-over phase happens. type Applier struct { - connectionConfig *mysql.ConnectionConfig - db *gosql.DB - singletonDB *gosql.DB - migrationContext *base.MigrationContext - name string + connectionConfig *mysql.ConnectionConfig + db *gosql.DB + singletonDB *gosql.DB + migrationContext *base.MigrationContext + finishedMigrating int64 + name string } -func NewApplier() *Applier { +func NewApplier(migrationContext *base.MigrationContext) *Applier { return &Applier{ - connectionConfig: base.GetMigrationContext().ApplierConnectionConfig, - migrationContext: base.GetMigrationContext(), - name: "applier", + connectionConfig: migrationContext.ApplierConnectionConfig, + migrationContext: migrationContext, + finishedMigrating: 0, + name: "applier", } } func (this *Applier) InitDBConnections() (err error) { + applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(applierUri); err != nil { + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil { return err } - singletonApplierUri := fmt.Sprintf("%s?timeout=0", applierUri) - if this.singletonDB, _, err = sqlutils.GetDB(singletonApplierUri); err != nil { + singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri) + if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil { return err } this.singletonDB.SetMaxOpenConns(1) - version, err := base.ValidateConnection(this.db, this.connectionConfig, this.name) + version, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name) if err != nil { return err } - if _, err := base.ValidateConnection(this.singletonDB, this.connectionConfig, this.name); err != nil { + if _, err := base.ValidateConnection(this.singletonDB, this.connectionConfig, this.migrationContext, this.name); err != nil { return err } this.migrationContext.ApplierMySQLVersion = version if err := this.validateAndReadTimeZone(); err != nil { return err } - if impliedKey, err := mysql.GetInstanceKey(this.db); err != nil { - return err - } else { - this.connectionConfig.ImpliedKey = impliedKey + if !this.migrationContext.AliyunRDS && !this.migrationContext.GoogleCloudPlatform && !this.migrationContext.AzureMySQL { + if impliedKey, err := mysql.GetInstanceKey(this.db); err != nil { + return err + } else { + this.connectionConfig.ImpliedKey = impliedKey + } } if err := this.readTableColumns(); err != nil { return err } - log.Infof("Applier initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.ApplierMySQLVersion) + this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.ApplierMySQLVersion) return nil } @@ -85,14 +112,14 @@ func (this *Applier) validateAndReadTimeZone() error { return err } - log.Infof("will use time_zone='%s' on applier", this.migrationContext.ApplierTimeZone) + this.migrationContext.Log.Infof("will use time_zone='%s' on applier", this.migrationContext.ApplierTimeZone) return nil } // readTableColumns reads table columns on applier func (this *Applier) readTableColumns() (err error) { - log.Infof("Examining table structure on applier") - this.migrationContext.OriginalTableColumnsOnApplier, err = mysql.GetTableColumns(this.db, this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName) + this.migrationContext.Log.Infof("Examining table structure on applier") + this.migrationContext.OriginalTableColumnsOnApplier, _, err = mysql.GetTableColumns(this.db, this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName) if err != nil { return err } @@ -101,7 +128,6 @@ func (this *Applier) readTableColumns() (err error) { // showTableStatus returns the output of `show table status like '...'` command func (this *Applier) showTableStatus(tableName string) (rowMap sqlutils.RowMap) { - rowMap = nil query := fmt.Sprintf(`show /* gh-ost */ table status from %s like '%s'`, sql.EscapeName(this.migrationContext.DatabaseName), tableName) sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { rowMap = m @@ -133,7 +159,7 @@ func (this *Applier) ValidateOrDropExistingTables() error { } } if len(this.migrationContext.GetOldTableName()) > mysql.MaxTableNameLength { - log.Fatalf("--timestamp-old-table defined, but resulting table name (%s) is too long (only %d characters allowed)", this.migrationContext.GetOldTableName(), mysql.MaxTableNameLength) + this.migrationContext.Log.Fatalf("--timestamp-old-table defined, but resulting table name (%s) is too long (only %d characters allowed)", this.migrationContext.GetOldTableName(), mysql.MaxTableNameLength) } if this.tableExists(this.migrationContext.GetOldTableName()) { @@ -151,14 +177,14 @@ func (this *Applier) CreateGhostTable() error { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), ) - log.Infof("Creating ghost table %s.%s", + this.migrationContext.Log.Infof("Creating ghost table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetGhostTableName()), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Ghost table created") + this.migrationContext.Log.Infof("Ghost table created") return nil } @@ -167,17 +193,17 @@ func (this *Applier) AlterGhost() error { query := fmt.Sprintf(`alter /* gh-ost */ table %s.%s %s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetGhostTableName()), - this.migrationContext.AlterStatement, + this.migrationContext.AlterStatementOptions, ) - log.Infof("Altering ghost table %s.%s", + this.migrationContext.Log.Infof("Altering ghost table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetGhostTableName()), ) - log.Debugf("ALTER statement: %s", query) + this.migrationContext.Log.Debugf("ALTER statement: %s", query) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Ghost table altered") + this.migrationContext.Log.Infof("Ghost table altered") return nil } @@ -198,14 +224,14 @@ func (this *Applier) CreateChangelogTable() error { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) - log.Infof("Creating changelog table %s.%s", + this.migrationContext.Log.Infof("Creating changelog table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Changelog table created") + this.migrationContext.Log.Infof("Changelog table created") return nil } @@ -215,14 +241,14 @@ func (this *Applier) dropTable(tableName string) error { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(tableName), ) - log.Infof("Droppping table %s.%s", + this.migrationContext.Log.Infof("Dropping table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(tableName), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Table dropped") + this.migrationContext.Log.Infof("Table dropped") return nil } @@ -265,7 +291,7 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) - _, err := sqlutils.Exec(this.db, query, explicitId, hint, value) + _, err := sqlutils.ExecNoPrepare(this.db, query, explicitId, hint, value) return hint, err } @@ -289,7 +315,7 @@ func (this *Applier) InitiateHeartbeat() { if _, err := this.WriteChangelog("heartbeat", time.Now().Format(time.RFC3339Nano)); err != nil { numSuccessiveFailures++ if numSuccessiveFailures > this.migrationContext.MaxRetries() { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } } else { numSuccessiveFailures = 0 @@ -300,6 +326,9 @@ func (this *Applier) InitiateHeartbeat() { heartbeatTick := time.Tick(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond) for range heartbeatTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } // Generally speaking, we would issue a goroutine, but I'd actually rather // have this block the loop rather than spam the master in the event something // goes wrong @@ -321,14 +350,14 @@ func (this *Applier) ExecuteThrottleQuery() (int64, error) { } var result int64 if err := this.db.QueryRow(throttleQuery).Scan(&result); err != nil { - return 0, log.Errore(err) + return 0, this.migrationContext.Log.Errore(err) } return result, nil } // ReadMigrationMinValues returns the minimum values to be iterated on rowcopy func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { - log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) + this.migrationContext.Log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &uniqueKey.Columns) if err != nil { return err @@ -343,13 +372,15 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { return err } } - log.Infof("Migration min values: [%s]", this.migrationContext.MigrationRangeMinValues) + this.migrationContext.Log.Infof("Migration min values: [%s]", this.migrationContext.MigrationRangeMinValues) + + err = rows.Err() return err } // ReadMigrationMaxValues returns the maximum values to be iterated on rowcopy func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { - log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) + this.migrationContext.Log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &uniqueKey.Columns) if err != nil { return err @@ -364,7 +395,9 @@ func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { return err } } - log.Infof("Migration max values: [%s]", this.migrationContext.MigrationRangeMaxValues) + this.migrationContext.Log.Infof("Migration max values: [%s]", this.migrationContext.MigrationRangeMaxValues) + + err = rows.Err() return err } @@ -382,7 +415,7 @@ func (this *Applier) ReadMigrationRangeValues() error { // CalculateNextIterationRangeEndValues reads the next-iteration-range-end unique key values, // which will be used for copying the next chunk of rows. Ir returns "false" if there is // no further chunk to work through, i.e. we're past the last chunk and are done with -// itrating the range (and this done with copying row chunks) +// iterating the range (and this done with copying row chunks) func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange bool, err error) { this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationIterationRangeMaxValues if this.migrationContext.MigrationIterationRangeMinValues == nil { @@ -417,12 +450,15 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo } hasFurtherRange = true } + if err = rows.Err(); err != nil { + return hasFurtherRange, err + } if hasFurtherRange { this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues return hasFurtherRange, nil } } - log.Debugf("Iteration complete: no further range to iterate") + this.migrationContext.Log.Debugf("Iteration complete: no further range to iterate") return hasFurtherRange, nil } @@ -454,10 +490,14 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected if err != nil { return nil, err } - sessionQuery := fmt.Sprintf(`SET - SESSION time_zone = '%s', - sql_mode = CONCAT(@@session.sql_mode, ',STRICT_ALL_TABLES') - `, this.migrationContext.ApplierTimeZone) + defer tx.Rollback() + sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone) + sqlModeAddendum := `,NO_AUTO_VALUE_ON_ZERO` + if !this.migrationContext.SkipStrictMode { + sqlModeAddendum = fmt.Sprintf("%s,STRICT_ALL_TABLES", sqlModeAddendum) + } + sessionQuery = fmt.Sprintf("%s, sql_mode = CONCAT(@@session.sql_mode, ',%s')", sessionQuery, sqlModeAddendum) + if _, err := tx.Exec(sessionQuery); err != nil { return nil, err } @@ -476,7 +516,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected } rowsAffected, _ = sqlResult.RowsAffected() duration = time.Since(startTime) - log.Debugf( + this.migrationContext.Log.Debugf( "Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d", this.migrationContext.MigrationIterationRangeMinValues, this.migrationContext.MigrationIterationRangeMaxValues, @@ -491,7 +531,7 @@ func (this *Applier) LockOriginalTable() error { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), ) - log.Infof("Locking %s.%s", + this.migrationContext.Log.Infof("Locking %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), ) @@ -499,18 +539,18 @@ func (this *Applier) LockOriginalTable() error { if _, err := sqlutils.ExecNoPrepare(this.singletonDB, query); err != nil { return err } - log.Infof("Table locked") + this.migrationContext.Log.Infof("Table locked") return nil } // UnlockTables makes tea. No wait, it unlocks tables. func (this *Applier) UnlockTables() error { query := `unlock /* gh-ost */ tables` - log.Infof("Unlocking tables") + this.migrationContext.Log.Infof("Unlocking tables") if _, err := sqlutils.ExecNoPrepare(this.singletonDB, query); err != nil { return err } - log.Infof("Tables unlocked") + this.migrationContext.Log.Infof("Tables unlocked") return nil } @@ -524,7 +564,7 @@ func (this *Applier) SwapTablesQuickAndBumpy() error { sql.EscapeName(this.migrationContext.OriginalTableName), sql.EscapeName(this.migrationContext.GetOldTableName()), ) - log.Infof("Renaming original table") + this.migrationContext.Log.Infof("Renaming original table") this.migrationContext.RenameTablesStartTime = time.Now() if _, err := sqlutils.ExecNoPrepare(this.singletonDB, query); err != nil { return err @@ -534,13 +574,13 @@ func (this *Applier) SwapTablesQuickAndBumpy() error { sql.EscapeName(this.migrationContext.GetGhostTableName()), sql.EscapeName(this.migrationContext.OriginalTableName), ) - log.Infof("Renaming ghost table") + this.migrationContext.Log.Infof("Renaming ghost table") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } this.migrationContext.RenameTablesEndTime = time.Now() - log.Infof("Tables renamed") + this.migrationContext.Log.Infof("Tables renamed") return nil } @@ -559,7 +599,7 @@ func (this *Applier) RenameTablesRollback() (renameError error) { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), ) - log.Infof("Renaming back both tables") + this.migrationContext.Log.Infof("Renaming back both tables") if _, err := sqlutils.ExecNoPrepare(this.db, query); err == nil { return nil } @@ -570,7 +610,7 @@ func (this *Applier) RenameTablesRollback() (renameError error) { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetGhostTableName()), ) - log.Infof("Renaming back to ghost table") + this.migrationContext.Log.Infof("Renaming back to ghost table") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { renameError = err } @@ -580,11 +620,11 @@ func (this *Applier) RenameTablesRollback() (renameError error) { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), ) - log.Infof("Renaming back to original table") + this.migrationContext.Log.Infof("Renaming back to original table") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { renameError = err } - return log.Errore(renameError) + return this.migrationContext.Log.Errore(renameError) } // StopSlaveIOThread is applicable with --test-on-replica; it stops the IO thread, duh. @@ -592,44 +632,44 @@ func (this *Applier) RenameTablesRollback() (renameError error) { // and have them written to the binary log, so that we can then read them via streamer. func (this *Applier) StopSlaveIOThread() error { query := `stop /* gh-ost */ slave io_thread` - log.Infof("Stopping replication IO thread") + this.migrationContext.Log.Infof("Stopping replication IO thread") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Replication IO thread stopped") + this.migrationContext.Log.Infof("Replication IO thread stopped") return nil } // StartSlaveIOThread is applicable with --test-on-replica func (this *Applier) StartSlaveIOThread() error { query := `start /* gh-ost */ slave io_thread` - log.Infof("Starting replication IO thread") + this.migrationContext.Log.Infof("Starting replication IO thread") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Replication IO thread started") + this.migrationContext.Log.Infof("Replication IO thread started") return nil } // StartSlaveSQLThread is applicable with --test-on-replica func (this *Applier) StopSlaveSQLThread() error { query := `stop /* gh-ost */ slave sql_thread` - log.Infof("Verifying SQL thread is stopped") + this.migrationContext.Log.Infof("Verifying SQL thread is stopped") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("SQL thread stopped") + this.migrationContext.Log.Infof("SQL thread stopped") return nil } // StartSlaveSQLThread is applicable with --test-on-replica func (this *Applier) StartSlaveSQLThread() error { query := `start /* gh-ost */ slave sql_thread` - log.Infof("Verifying SQL thread is running") + this.migrationContext.Log.Infof("Verifying SQL thread is running") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("SQL thread started") + this.migrationContext.Log.Infof("SQL thread started") return nil } @@ -646,7 +686,7 @@ func (this *Applier) StopReplication() error { if err != nil { return err } - log.Infof("Replication IO thread at %+v. SQL thread is at %+v", *readBinlogCoordinates, *executeBinlogCoordinates) + this.migrationContext.Log.Infof("Replication IO thread at %+v. SQL thread is at %+v", *readBinlogCoordinates, *executeBinlogCoordinates) return nil } @@ -658,7 +698,7 @@ func (this *Applier) StartReplication() error { if err := this.StartSlaveSQLThread(); err != nil { return err } - log.Infof("Replication started") + this.migrationContext.Log.Infof("Replication started") return nil } @@ -672,7 +712,7 @@ func (this *Applier) ExpectUsedLock(sessionId int64) error { var result int64 query := `select is_used_lock(?)` lockName := this.GetSessionLockName(sessionId) - log.Infof("Checking session lock: %s", lockName) + this.migrationContext.Log.Infof("Checking session lock: %s", lockName) if err := this.db.QueryRow(query, lockName).Scan(&result); err != nil || result != sessionId { return fmt.Errorf("Session lock %s expected to be found but wasn't", lockName) } @@ -707,7 +747,7 @@ func (this *Applier) ExpectProcess(sessionId int64, stateHint, infoHint string) // DropAtomicCutOverSentryTableIfExists checks if the "old" table name // happens to be a cut-over magic table; if so, it drops it. func (this *Applier) DropAtomicCutOverSentryTableIfExists() error { - log.Infof("Looking for magic cut-over table") + this.migrationContext.Log.Infof("Looking for magic cut-over table") tableName := this.migrationContext.GetOldTableName() rowMap := this.showTableStatus(tableName) if rowMap == nil { @@ -717,7 +757,7 @@ func (this *Applier) DropAtomicCutOverSentryTableIfExists() error { if rowMap["Comment"].String != atomicCutOverMagicHint { return fmt.Errorf("Expected magic comment on %s, did not find it", tableName) } - log.Infof("Dropping magic cut-over table") + this.migrationContext.Log.Infof("Dropping magic cut-over table") return this.dropTable(tableName) } @@ -737,20 +777,20 @@ func (this *Applier) CreateAtomicCutOverSentryTable() error { this.migrationContext.TableEngine, atomicCutOverMagicHint, ) - log.Infof("Creating magic cut-over table %s.%s", + this.migrationContext.Log.Infof("Creating magic cut-over table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(tableName), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Magic cut-over table created") + this.migrationContext.Log.Infof("Magic cut-over table created") return nil } // AtomicCutOverMagicLock -func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocked chan<- error, okToUnlockTable <-chan bool, tableUnlocked chan<- error) error { +func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocked chan<- error, okToUnlockTable <-chan bool, tableUnlocked chan<- error, dropCutOverSentryTableOnce *sync.Once) error { tx, err := this.db.Begin() if err != nil { tableLocked <- err @@ -773,7 +813,7 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke lockResult := 0 query := `select get_lock(?, 0)` lockName := this.GetSessionLockName(sessionId) - log.Infof("Grabbing voluntary lock: %s", lockName) + this.migrationContext.Log.Infof("Grabbing voluntary lock: %s", lockName) if err := tx.QueryRow(query, lockName).Scan(&lockResult); err != nil || lockResult != 1 { err := fmt.Errorf("Unable to acquire lock %s", lockName) tableLocked <- err @@ -781,7 +821,7 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke } tableLockTimeoutSeconds := this.migrationContext.CutOverLockTimeoutSeconds * 2 - log.Infof("Setting LOCK timeout as %d seconds", tableLockTimeoutSeconds) + this.migrationContext.Log.Infof("Setting LOCK timeout as %d seconds", tableLockTimeoutSeconds) query = fmt.Sprintf(`set session lock_wait_timeout:=%d`, tableLockTimeoutSeconds) if _, err := tx.Exec(query); err != nil { tableLocked <- err @@ -799,7 +839,7 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetOldTableName()), ) - log.Infof("Locking %s.%s, %s.%s", + this.migrationContext.Log.Infof("Locking %s.%s, %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), sql.EscapeName(this.migrationContext.DatabaseName), @@ -810,7 +850,7 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke tableLocked <- err return err } - log.Infof("Tables locked") + this.migrationContext.Log.Infof("Tables locked") tableLocked <- nil // No error. // From this point on, we are committed to UNLOCK TABLES. No matter what happens, @@ -819,22 +859,25 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke // The cut-over phase will proceed to apply remaining backlog onto ghost table, // and issue RENAME. We wait here until told to proceed. <-okToUnlockTable - log.Infof("Will now proceed to drop magic table and unlock tables") + this.migrationContext.Log.Infof("Will now proceed to drop magic table and unlock tables") // The magic table is here because we locked it. And we are the only ones allowed to drop it. // And in fact, we will: - log.Infof("Dropping magic cut-over table") + this.migrationContext.Log.Infof("Dropping magic cut-over table") query = fmt.Sprintf(`drop /* gh-ost */ table if exists %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetOldTableName()), ) - if _, err := tx.Exec(query); err != nil { - log.Errore(err) - // We DO NOT return here because we must `UNLOCK TABLES`! - } + + dropCutOverSentryTableOnce.Do(func() { + if _, err := tx.Exec(query); err != nil { + this.migrationContext.Log.Errore(err) + // We DO NOT return here because we must `UNLOCK TABLES`! + } + }) // Tables still locked - log.Infof("Releasing lock from %s.%s, %s.%s", + this.migrationContext.Log.Infof("Releasing lock from %s.%s, %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), sql.EscapeName(this.migrationContext.DatabaseName), @@ -843,9 +886,9 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke query = `unlock tables` if _, err := tx.Exec(query); err != nil { tableUnlocked <- err - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } - log.Infof("Tables unlocked") + this.migrationContext.Log.Infof("Tables unlocked") tableUnlocked <- nil return nil } @@ -867,7 +910,7 @@ func (this *Applier) AtomicCutoverRename(sessionIdChan chan int64, tablesRenamed } sessionIdChan <- sessionId - log.Infof("Setting RENAME timeout as %d seconds", this.migrationContext.CutOverLockTimeoutSeconds) + this.migrationContext.Log.Infof("Setting RENAME timeout as %d seconds", this.migrationContext.CutOverLockTimeoutSeconds) query := fmt.Sprintf(`set session lock_wait_timeout:=%d`, this.migrationContext.CutOverLockTimeoutSeconds) if _, err := tx.Exec(query); err != nil { return err @@ -883,13 +926,13 @@ func (this *Applier) AtomicCutoverRename(sessionIdChan chan int64, tablesRenamed sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), ) - log.Infof("Issuing and expecting this to block: %s", query) + this.migrationContext.Log.Infof("Issuing and expecting this to block: %s", query) if _, err := tx.Exec(query); err != nil { tablesRenamed <- err - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } tablesRenamed <- nil - log.Infof("Tables renamed") + this.migrationContext.Log.Infof("Tables renamed") return nil } @@ -901,81 +944,52 @@ func (this *Applier) ShowStatusVariable(variableName string) (result int64, err return result, nil } +// updateModifiesUniqueKeyColumns checks whether a UPDATE DML event actually +// modifies values of the migration's unique key (the iterated key). This will call +// for special handling. +func (this *Applier) updateModifiesUniqueKeyColumns(dmlEvent *binlog.BinlogDMLEvent) (modifiedColumn string, isModified bool) { + for _, column := range this.migrationContext.UniqueKey.Columns.Columns() { + tableOrdinal := this.migrationContext.OriginalTableColumns.Ordinals[column.Name] + whereColumnValue := dmlEvent.WhereColumnValues.AbstractValues()[tableOrdinal] + newColumnValue := dmlEvent.NewColumnValues.AbstractValues()[tableOrdinal] + if newColumnValue != whereColumnValue { + return column.Name, true + } + } + return "", false +} + // buildDMLEventQuery creates a query to operate on the ghost table, based on an intercepted binlog // event entry on the original table. -func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (query string, args []interface{}, rowsDelta int64, err error) { +func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (results [](*dmlBuildResult)) { switch dmlEvent.DML { case binlog.DeleteDML: { query, uniqueKeyArgs, err := sql.BuildDMLDeleteQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.WhereColumnValues.AbstractValues()) - return query, uniqueKeyArgs, -1, err + return append(results, newDmlBuildResult(query, uniqueKeyArgs, -1, err)) } case binlog.InsertDML: { query, sharedArgs, err := sql.BuildDMLInsertQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, this.migrationContext.MappedSharedColumns, dmlEvent.NewColumnValues.AbstractValues()) - return query, sharedArgs, 1, err + return append(results, newDmlBuildResult(query, sharedArgs, 1, err)) } case binlog.UpdateDML: { + if _, isModified := this.updateModifiesUniqueKeyColumns(dmlEvent); isModified { + dmlEvent.DML = binlog.DeleteDML + results = append(results, this.buildDMLEventQuery(dmlEvent)...) + dmlEvent.DML = binlog.InsertDML + results = append(results, this.buildDMLEventQuery(dmlEvent)...) + return results + } query, sharedArgs, uniqueKeyArgs, err := sql.BuildDMLUpdateQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, this.migrationContext.MappedSharedColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues()) + args := sqlutils.Args() args = append(args, sharedArgs...) args = append(args, uniqueKeyArgs...) - return query, args, 0, err + return append(results, newDmlBuildResult(query, args, 0, err)) } } - return "", args, 0, fmt.Errorf("Unknown dml event type: %+v", dmlEvent.DML) -} - -// ApplyDMLEventQuery writes an entry to the ghost table, in response to an intercepted -// original-table binlog event -func (this *Applier) ApplyDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) error { - query, args, rowDelta, err := this.buildDMLEventQuery(dmlEvent) - if err != nil { - return err - } - // TODO The below is in preparation for transactional writes on the ghost tables. - // Such writes would be, for example: - // - prepended with sql_mode setup - // - prepended with time zone setup - // - prepended with SET SQL_LOG_BIN=0 - // - prepended with SET FK_CHECKS=0 - // etc. - // - // a known problem: https://github.com/golang/go/issues/9373 -- bitint unsigned values, not supported in database/sql - // is solved by silently converting unsigned bigints to string values. - // - - err = func() error { - tx, err := this.db.Begin() - if err != nil { - return err - } - sessionQuery := `SET - SESSION time_zone = '+00:00', - sql_mode = CONCAT(@@session.sql_mode, ',STRICT_ALL_TABLES') - ` - if _, err := tx.Exec(sessionQuery); err != nil { - return err - } - if _, err := tx.Exec(query, args...); err != nil { - return err - } - if err := tx.Commit(); err != nil { - return err - } - return nil - }() - - if err != nil { - err = fmt.Errorf("%s; query=%s; args=%+v", err.Error(), query, args) - return log.Errore(err) - } - // no error - atomic.AddInt64(&this.migrationContext.TotalDMLEventsApplied, 1) - if this.migrationContext.CountTableRows { - atomic.AddInt64(&this.migrationContext.RowsDeltaEstimate, rowDelta) - } - return nil + return append(results, newDmlBuildResultError(fmt.Errorf("Unknown dml event type: %+v", dmlEvent.DML))) } // ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table @@ -994,23 +1008,28 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) return err } - sessionQuery := `SET - SESSION time_zone = '+00:00', - sql_mode = CONCAT(@@session.sql_mode, ',STRICT_ALL_TABLES') - ` + sessionQuery := "SET SESSION time_zone = '+00:00'" + + sqlModeAddendum := `,NO_AUTO_VALUE_ON_ZERO` + if !this.migrationContext.SkipStrictMode { + sqlModeAddendum = fmt.Sprintf("%s,STRICT_ALL_TABLES", sqlModeAddendum) + } + sessionQuery = fmt.Sprintf("%s, sql_mode = CONCAT(@@session.sql_mode, ',%s')", sessionQuery, sqlModeAddendum) + if _, err := tx.Exec(sessionQuery); err != nil { return rollback(err) } for _, dmlEvent := range dmlEvents { - query, args, rowDelta, err := this.buildDMLEventQuery(dmlEvent) - if err != nil { - return rollback(err) + for _, buildResult := range this.buildDMLEventQuery(dmlEvent) { + if buildResult.err != nil { + return rollback(buildResult.err) + } + if _, err := tx.Exec(buildResult.query, buildResult.args...); err != nil { + err = fmt.Errorf("%s; query=%s; args=%+v", err.Error(), buildResult.query, buildResult.args) + return rollback(err) + } + totalDelta += buildResult.rowsDelta } - if _, err := tx.Exec(query, args...); err != nil { - err = fmt.Errorf("%s; query=%s; args=%+v", err.Error(), query, args) - return rollback(err) - } - totalDelta += rowDelta } if err := tx.Commit(); err != nil { return err @@ -1019,13 +1038,20 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) }() if err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } // no error atomic.AddInt64(&this.migrationContext.TotalDMLEventsApplied, int64(len(dmlEvents))) if this.migrationContext.CountTableRows { atomic.AddInt64(&this.migrationContext.RowsDeltaEstimate, totalDelta) } - log.Debugf("ApplyDMLEventQueries() applied %d events in one transaction", len(dmlEvents)) + this.migrationContext.Log.Debugf("ApplyDMLEventQueries() applied %d events in one transaction", len(dmlEvents)) return nil } + +func (this *Applier) Teardown() { + this.migrationContext.Log.Debugf("Tearing down...") + this.db.Close() + this.singletonDB.Close() + atomic.StoreInt64(&this.finishedMigrating, 1) +} diff --git a/go/logic/hooks.go b/go/logic/hooks.go index 58825ee..fa5011e 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -37,9 +37,9 @@ type HooksExecutor struct { migrationContext *base.MigrationContext } -func NewHooksExecutor() *HooksExecutor { +func NewHooksExecutor(migrationContext *base.MigrationContext) *HooksExecutor { return &HooksExecutor{ - migrationContext: base.GetMigrationContext(), + migrationContext: migrationContext, } } @@ -63,7 +63,12 @@ func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) [ env = append(env, fmt.Sprintf("GH_OST_MIGRATED_HOST=%s", this.migrationContext.GetApplierHostname())) env = append(env, fmt.Sprintf("GH_OST_INSPECTED_HOST=%s", this.migrationContext.GetInspectorHostname())) env = append(env, fmt.Sprintf("GH_OST_EXECUTING_HOST=%s", this.migrationContext.Hostname)) + env = append(env, fmt.Sprintf("GH_OST_INSPECTED_LAG=%f", this.migrationContext.GetCurrentLagDuration().Seconds())) + env = append(env, fmt.Sprintf("GH_OST_PROGRESS=%f", this.migrationContext.GetProgressPct())) env = append(env, fmt.Sprintf("GH_OST_HOOKS_HINT=%s", this.migrationContext.HooksHintMessage)) + env = append(env, fmt.Sprintf("GH_OST_HOOKS_HINT_OWNER=%s", this.migrationContext.HooksHintOwner)) + env = append(env, fmt.Sprintf("GH_OST_HOOKS_HINT_TOKEN=%s", this.migrationContext.HooksHintToken)) + env = append(env, fmt.Sprintf("GH_OST_DRY_RUN=%t", this.migrationContext.Noop)) for _, variable := range extraVariables { env = append(env, variable) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 1c18642..2c1846b 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -17,7 +17,6 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - "github.com/outbrain/golib/log" "github.com/outbrain/golib/sqlutils" ) @@ -26,32 +25,41 @@ const startSlavePostWaitMilliseconds = 500 * time.Millisecond // Inspector reads data from the read-MySQL-server (typically a replica, but can be the master) // It is used for gaining initial status and structure, and later also follow up on progress and changelog type Inspector struct { - connectionConfig *mysql.ConnectionConfig - db *gosql.DB - migrationContext *base.MigrationContext - name string + connectionConfig *mysql.ConnectionConfig + db *gosql.DB + informationSchemaDb *gosql.DB + migrationContext *base.MigrationContext + name string } -func NewInspector() *Inspector { +func NewInspector(migrationContext *base.MigrationContext) *Inspector { return &Inspector{ - connectionConfig: base.GetMigrationContext().InspectorConnectionConfig, - migrationContext: base.GetMigrationContext(), - name: "inspector", + connectionConfig: migrationContext.InspectorConnectionConfig, + migrationContext: migrationContext, + name: "inspector", } } func (this *Inspector) InitDBConnections() (err error) { inspectorUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(inspectorUri); err != nil { + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, inspectorUri); err != nil { return err } + + informationSchemaUri := this.connectionConfig.GetDBUri("information_schema") + if this.informationSchemaDb, _, err = mysql.GetDB(this.migrationContext.Uuid, informationSchemaUri); err != nil { + return err + } + if err := this.validateConnection(); err != nil { return err } - if impliedKey, err := mysql.GetInstanceKey(this.db); err != nil { - return err - } else { - this.connectionConfig.ImpliedKey = impliedKey + if !this.migrationContext.AliyunRDS && !this.migrationContext.GoogleCloudPlatform && !this.migrationContext.AzureMySQL { + if impliedKey, err := mysql.GetInstanceKey(this.db); err != nil { + return err + } else { + this.connectionConfig.ImpliedKey = impliedKey + } } if err := this.validateGrants(); err != nil { return err @@ -62,7 +70,7 @@ func (this *Inspector) InitDBConnections() (err error) { if err := this.applyBinlogFormat(); err != nil { return err } - log.Infof("Inspector initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.InspectorMySQLVersion) + this.migrationContext.Log.Infof("Inspector initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.InspectorMySQLVersion) return nil } @@ -82,24 +90,24 @@ func (this *Inspector) ValidateOriginalTable() (err error) { return nil } -func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { +func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, virtualColumns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { uniqueKeys, err = this.getCandidateUniqueKeys(tableName) if err != nil { - return columns, uniqueKeys, err + return columns, virtualColumns, uniqueKeys, err } if len(uniqueKeys) == 0 { - return columns, uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") + return columns, virtualColumns, uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") } - columns, err = mysql.GetTableColumns(this.db, this.migrationContext.DatabaseName, tableName) + columns, virtualColumns, err = mysql.GetTableColumns(this.db, this.migrationContext.DatabaseName, tableName) if err != nil { - return columns, uniqueKeys, err + return columns, virtualColumns, uniqueKeys, err } - return columns, uniqueKeys, nil + return columns, virtualColumns, uniqueKeys, nil } func (this *Inspector) InspectOriginalTable() (err error) { - this.migrationContext.OriginalTableColumns, this.migrationContext.OriginalTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.OriginalTableName) + this.migrationContext.OriginalTableColumns, this.migrationContext.OriginalTableVirtualColumns, this.migrationContext.OriginalTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.OriginalTableName) if err != nil { return err } @@ -115,7 +123,7 @@ func (this *Inspector) inspectOriginalAndGhostTables() (err error) { return fmt.Errorf("It seems like table structure is not identical between master and replica. This scenario is not supported.") } - this.migrationContext.GhostTableColumns, this.migrationContext.GhostTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.GetGhostTableName()) + this.migrationContext.GhostTableColumns, this.migrationContext.GhostTableVirtualColumns, this.migrationContext.GhostTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.GetGhostTableName()) if err != nil { return err } @@ -130,14 +138,14 @@ func (this *Inspector) inspectOriginalAndGhostTables() (err error) { switch column.Type { case sql.FloatColumnType: { - log.Warning("Will not use %+v as shared key due to FLOAT data type", sharedUniqueKey.Name) + this.migrationContext.Log.Warning("Will not use %+v as shared key due to FLOAT data type", sharedUniqueKey.Name) uniqueKeyIsValid = false } case sql.JSONColumnType: { // Noteworthy that at this time MySQL does not allow JSON indexing anyhow, but this code // will remain in place to potentially handle the future case where JSON is supported in indexes. - log.Warning("Will not use %+v as shared key due to JSON data type", sharedUniqueKey.Name) + this.migrationContext.Log.Warning("Will not use %+v as shared key due to JSON data type", sharedUniqueKey.Name) uniqueKeyIsValid = false } } @@ -150,29 +158,23 @@ func (this *Inspector) inspectOriginalAndGhostTables() (err error) { if this.migrationContext.UniqueKey == nil { return fmt.Errorf("No shared unique key can be found after ALTER! Bailing out") } - log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name) + this.migrationContext.Log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name) if this.migrationContext.UniqueKey.HasNullable { if this.migrationContext.NullableUniqueKeyAllowed { - log.Warningf("Chosen key (%s) has nullable columns. You have supplied with --allow-nullable-unique-key and so this migration proceeds. As long as there aren't NULL values in this key's column, migration should be fine. NULL values will corrupt migration's data", this.migrationContext.UniqueKey) + this.migrationContext.Log.Warningf("Chosen key (%s) has nullable columns. You have supplied with --allow-nullable-unique-key and so this migration proceeds. As long as there aren't NULL values in this key's column, migration should be fine. NULL values will corrupt migration's data", this.migrationContext.UniqueKey) } else { return fmt.Errorf("Chosen key (%s) has nullable columns. Bailing out. To force this operation to continue, supply --allow-nullable-unique-key flag. Only do so if you are certain there are no actual NULL values in this key. As long as there aren't, migration should be fine. NULL values in columns of this key will corrupt migration's data", this.migrationContext.UniqueKey) } } - if !this.migrationContext.UniqueKey.IsPrimary() { - if this.migrationContext.OriginalBinlogRowImage != "FULL" { - return fmt.Errorf("binlog_row_image is '%s' and chosen key is %s, which is not the primary key. This operation cannot proceed. You may `set global binlog_row_image='full'` and try again", this.migrationContext.OriginalBinlogRowImage, this.migrationContext.UniqueKey) - } - } - this.migrationContext.SharedColumns, this.migrationContext.MappedSharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns, this.migrationContext.ColumnRenameMap) - log.Infof("Shared columns are %s", this.migrationContext.SharedColumns) + this.migrationContext.SharedColumns, this.migrationContext.MappedSharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns, this.migrationContext.OriginalTableVirtualColumns, this.migrationContext.GhostTableVirtualColumns, this.migrationContext.ColumnRenameMap) + this.migrationContext.Log.Infof("Shared columns are %s", this.migrationContext.SharedColumns) // By fact that a non-empty unique key exists we also know the shared columns are non-empty // This additional step looks at which columns are unsigned. We could have merged this within // the `getTableColumns()` function, but it's a later patch and introduces some complexity; I feel // comfortable in doing this as a separate step. - this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns) - this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &this.migrationContext.UniqueKey.Columns) + this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns) this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.GhostTableColumns, this.migrationContext.MappedSharedColumns) for i := range this.migrationContext.SharedColumns.Columns() { @@ -198,13 +200,13 @@ func (this *Inspector) validateConnection() error { return fmt.Errorf("MySQL replication length limited to 32 characters. See https://dev.mysql.com/doc/refman/5.7/en/assigning-passwords.html") } - version, err := base.ValidateConnection(this.db, this.connectionConfig, this.name) + version, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name) this.migrationContext.InspectorMySQLVersion = version return err } // validateGrants verifies the user by which we're executing has necessary grants -// to do its thang. +// to do its thing. func (this *Inspector) validateGrants() error { query := `show /* gh-ost */ grants for current_user()` foundAll := false @@ -231,6 +233,9 @@ func (this *Inspector) validateGrants() error { if strings.Contains(grant, fmt.Sprintf("GRANT ALL PRIVILEGES ON `%s`.*", this.migrationContext.DatabaseName)) { foundDBAll = true } + if strings.Contains(grant, fmt.Sprintf("GRANT ALL PRIVILEGES ON `%s`.*", strings.Replace(this.migrationContext.DatabaseName, "_", "\\_", -1))) { + foundDBAll = true + } if base.StringContainsAll(grant, `ALTER`, `CREATE`, `DELETE`, `DROP`, `INDEX`, `INSERT`, `LOCK TABLES`, `SELECT`, `TRIGGER`, `UPDATE`, ` ON *.*`) { foundDBAll = true } @@ -246,27 +251,27 @@ func (this *Inspector) validateGrants() error { this.migrationContext.HasSuperPrivilege = foundSuper if foundAll { - log.Infof("User has ALL privileges") + this.migrationContext.Log.Infof("User has ALL privileges") return nil } if foundSuper && foundReplicationSlave && foundDBAll { - log.Infof("User has SUPER, REPLICATION SLAVE privileges, and has ALL privileges on %s.*", sql.EscapeName(this.migrationContext.DatabaseName)) + this.migrationContext.Log.Infof("User has SUPER, REPLICATION SLAVE privileges, and has ALL privileges on %s.*", sql.EscapeName(this.migrationContext.DatabaseName)) return nil } if foundReplicationClient && foundReplicationSlave && foundDBAll { - log.Infof("User has REPLICATION CLIENT, REPLICATION SLAVE privileges, and has ALL privileges on %s.*", sql.EscapeName(this.migrationContext.DatabaseName)) + this.migrationContext.Log.Infof("User has REPLICATION CLIENT, REPLICATION SLAVE privileges, and has ALL privileges on %s.*", sql.EscapeName(this.migrationContext.DatabaseName)) return nil } - log.Debugf("Privileges: Super: %t, REPLICATION CLIENT: %t, REPLICATION SLAVE: %t, ALL on *.*: %t, ALL on %s.*: %t", foundSuper, foundReplicationClient, foundReplicationSlave, foundAll, sql.EscapeName(this.migrationContext.DatabaseName), foundDBAll) - return log.Errorf("User has insufficient privileges for migration. Needed: SUPER|REPLICATION CLIENT, REPLICATION SLAVE and ALL on %s.*", sql.EscapeName(this.migrationContext.DatabaseName)) + this.migrationContext.Log.Debugf("Privileges: Super: %t, REPLICATION CLIENT: %t, REPLICATION SLAVE: %t, ALL on *.*: %t, ALL on %s.*: %t", foundSuper, foundReplicationClient, foundReplicationSlave, foundAll, sql.EscapeName(this.migrationContext.DatabaseName), foundDBAll) + return this.migrationContext.Log.Errorf("User has insufficient privileges for migration. Needed: SUPER|REPLICATION CLIENT, REPLICATION SLAVE and ALL on %s.*", sql.EscapeName(this.migrationContext.DatabaseName)) } // restartReplication is required so that we are _certain_ the binlog format and // row image settings have actually been applied to the replication thread. -// It is entriely possible, for example, that the replication is using 'STATEMENT' +// It is entirely possible, for example, that the replication is using 'STATEMENT' // binlog format even as the variable says 'ROW' func (this *Inspector) restartReplication() error { - log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + this.migrationContext.Log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) masterKey, _ := mysql.GetMasterKeyFromSlaveStatus(this.connectionConfig) if masterKey == nil { @@ -285,7 +290,7 @@ func (this *Inspector) restartReplication() error { } time.Sleep(startSlavePostWaitMilliseconds) - log.Debugf("Replication restarted") + this.migrationContext.Log.Debugf("Replication restarted") return nil } @@ -305,7 +310,7 @@ func (this *Inspector) applyBinlogFormat() error { if err := this.restartReplication(); err != nil { return err } - log.Debugf("'ROW' binlog format applied") + this.migrationContext.Log.Debugf("'ROW' binlog format applied") return nil } // We already have RBR, no explicit switch @@ -343,7 +348,7 @@ func (this *Inspector) validateBinlogs() error { if countReplicas > 0 { return fmt.Errorf("%s:%d has %s binlog_format, but I'm too scared to change it to ROW because it has replicas. Bailing out", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogFormat) } - log.Infof("%s:%d has %s binlog_format. I will change it to ROW, and will NOT change it back, even in the event of failure.", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogFormat) + this.migrationContext.Log.Infof("%s:%d has %s binlog_format. I will change it to ROW, and will NOT change it back, even in the event of failure.", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogFormat) } query = `select @@global.binlog_row_image` if err := this.db.QueryRow(query).Scan(&this.migrationContext.OriginalBinlogRowImage); err != nil { @@ -351,8 +356,11 @@ func (this *Inspector) validateBinlogs() error { this.migrationContext.OriginalBinlogRowImage = "FULL" } this.migrationContext.OriginalBinlogRowImage = strings.ToUpper(this.migrationContext.OriginalBinlogRowImage) + if this.migrationContext.OriginalBinlogRowImage != "FULL" { + return fmt.Errorf("%s:%d has '%s' binlog_row_image, and only 'FULL' is supported. This operation cannot proceed. You may `set global binlog_row_image='full'` and try again", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port, this.migrationContext.OriginalBinlogRowImage) + } - log.Infof("binary logs validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + this.migrationContext.Log.Infof("binary logs validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) return nil } @@ -365,12 +373,12 @@ func (this *Inspector) validateLogSlaveUpdates() error { } if logSlaveUpdates { - log.Infof("log_slave_updates validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + this.migrationContext.Log.Infof("log_slave_updates validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) return nil } if this.migrationContext.IsTungsten { - log.Warningf("log_slave_updates not found on %s:%d, but --tungsten provided, so I'm proceeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + this.migrationContext.Log.Warningf("log_slave_updates not found on %s:%d, but --tungsten provided, so I'm proceeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) return nil } @@ -379,7 +387,7 @@ func (this *Inspector) validateLogSlaveUpdates() error { } if this.migrationContext.InspectorIsAlsoApplier() { - log.Warningf("log_slave_updates not found on %s:%d, but executing directly on master, so I'm proceeeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + this.migrationContext.Log.Warningf("log_slave_updates not found on %s:%d, but executing directly on master, so I'm proceeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) return nil } @@ -406,17 +414,17 @@ func (this *Inspector) validateTable() error { return err } if !tableFound { - return log.Errorf("Cannot find table %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + return this.migrationContext.Log.Errorf("Cannot find table %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) } - log.Infof("Table found. Engine=%s", this.migrationContext.TableEngine) - log.Debugf("Estimated number of rows via STATUS: %d", this.migrationContext.RowsEstimate) + this.migrationContext.Log.Infof("Table found. Engine=%s", this.migrationContext.TableEngine) + this.migrationContext.Log.Debugf("Estimated number of rows via STATUS: %d", this.migrationContext.RowsEstimate) return nil } // validateTableForeignKeys makes sure no foreign keys exist on the migrated table func (this *Inspector) validateTableForeignKeys(allowChildForeignKeys bool) error { if this.migrationContext.SkipForeignKeyChecks { - log.Warning("--skip-foreign-key-checks provided: will not check for foreign keys") + this.migrationContext.Log.Warning("--skip-foreign-key-checks provided: will not check for foreign keys") return nil } query := ` @@ -450,16 +458,16 @@ func (this *Inspector) validateTableForeignKeys(allowChildForeignKeys bool) erro return err } if numParentForeignKeys > 0 { - return log.Errorf("Found %d parent-side foreign keys on %s.%s. Parent-side foreign keys are not supported. Bailing out", numParentForeignKeys, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + return this.migrationContext.Log.Errorf("Found %d parent-side foreign keys on %s.%s. Parent-side foreign keys are not supported. Bailing out", numParentForeignKeys, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) } if numChildForeignKeys > 0 { if allowChildForeignKeys { - log.Debugf("Foreign keys found and will be dropped, as per given --discard-foreign-keys flag") + this.migrationContext.Log.Debugf("Foreign keys found and will be dropped, as per given --discard-foreign-keys flag") return nil } - return log.Errorf("Found %d child-side foreign keys on %s.%s. Child-side foreign keys are not supported. Bailing out", numChildForeignKeys, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + return this.migrationContext.Log.Errorf("Found %d child-side foreign keys on %s.%s. Child-side foreign keys are not supported. Bailing out", numChildForeignKeys, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) } - log.Debugf("Validated no foreign keys exist on table") + this.migrationContext.Log.Debugf("Validated no foreign keys exist on table") return nil } @@ -485,9 +493,9 @@ func (this *Inspector) validateTableTriggers() error { return err } if numTriggers > 0 { - return log.Errorf("Found triggers on %s.%s. Triggers are not supported at this time. Bailing out", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + return this.migrationContext.Log.Errorf("Found triggers on %s.%s. Triggers are not supported at this time. Bailing out", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) } - log.Debugf("Validated no triggers exist on table") + this.migrationContext.Log.Debugf("Validated no triggers exist on table") return nil } @@ -507,9 +515,9 @@ func (this *Inspector) estimateTableRowsViaExplain() error { return err } if !outputFound { - return log.Errorf("Cannot run EXPLAIN on %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + return this.migrationContext.Log.Errorf("Cannot run EXPLAIN on %s.%s!", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) } - log.Infof("Estimated number of rows via EXPLAIN: %d", this.migrationContext.RowsEstimate) + this.migrationContext.Log.Infof("Estimated number of rows via EXPLAIN: %d", this.migrationContext.RowsEstimate) return nil } @@ -518,7 +526,7 @@ func (this *Inspector) CountTableRows() error { atomic.StoreInt64(&this.migrationContext.CountingRowsFlag, 1) defer atomic.StoreInt64(&this.migrationContext.CountingRowsFlag, 0) - log.Infof("As instructed, I'm issuing a SELECT COUNT(*) on the table. This may take a while") + this.migrationContext.Log.Infof("As instructed, I'm issuing a SELECT COUNT(*) on the table. This may take a while") query := fmt.Sprintf(`select /* gh-ost */ count(*) as rows from %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) var rowsEstimate int64 @@ -528,7 +536,7 @@ func (this *Inspector) CountTableRows() error { atomic.StoreInt64(&this.migrationContext.RowsEstimate, rowsEstimate) this.migrationContext.UsedRowsEstimateMethod = base.CountRowsEstimate - log.Infof("Exact number of rows via COUNT: %d", rowsEstimate) + this.migrationContext.Log.Infof("Exact number of rows via COUNT: %d", rowsEstimate) return nil } @@ -547,44 +555,35 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { columnName := m.GetString("COLUMN_NAME") columnType := m.GetString("COLUMN_TYPE") - if strings.Contains(columnType, "unsigned") { - for _, columnsList := range columnsLists { - columnsList.SetUnsigned(columnName) + for _, columnsList := range columnsLists { + column := columnsList.GetColumn(columnName) + if column == nil { + continue } - } - if strings.Contains(columnType, "mediumint") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.MediumIntColumnType + + if strings.Contains(columnType, "unsigned") { + column.IsUnsigned = true } - } - if strings.Contains(columnType, "timestamp") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.TimestampColumnType + if strings.Contains(columnType, "mediumint") { + column.Type = sql.MediumIntColumnType } - } - if strings.Contains(columnType, "datetime") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.DateTimeColumnType + if strings.Contains(columnType, "timestamp") { + column.Type = sql.TimestampColumnType } - } - if strings.Contains(columnType, "json") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.JSONColumnType + if strings.Contains(columnType, "datetime") { + column.Type = sql.DateTimeColumnType } - } - if strings.Contains(columnType, "float") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.FloatColumnType + if strings.Contains(columnType, "json") { + column.Type = sql.JSONColumnType } - } - if strings.HasPrefix(columnType, "enum") { - for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.EnumColumnType + if strings.Contains(columnType, "float") { + column.Type = sql.FloatColumnType } - } - if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { - for _, columnsList := range columnsLists { - columnsList.SetCharset(columnName, charset) + if strings.HasPrefix(columnType, "enum") { + column.Type = sql.EnumColumnType + } + if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { + column.Charset = charset } } return nil @@ -624,8 +623,6 @@ func (this *Inspector) getCandidateUniqueKeys(tableName string) (uniqueKeys [](* GROUP BY TABLE_SCHEMA, TABLE_NAME, INDEX_NAME ) AS UNIQUES ON ( - COLUMNS.TABLE_SCHEMA = UNIQUES.TABLE_SCHEMA AND - COLUMNS.TABLE_NAME = UNIQUES.TABLE_NAME AND COLUMNS.COLUMN_NAME = UNIQUES.FIRST_COLUMN_NAME ) WHERE @@ -667,7 +664,7 @@ func (this *Inspector) getCandidateUniqueKeys(tableName string) (uniqueKeys [](* if err != nil { return uniqueKeys, err } - log.Debugf("Potential unique keys in %+v: %+v", tableName, uniqueKeys) + this.migrationContext.Log.Debugf("Potential unique keys in %+v: %+v", tableName, uniqueKeys) return uniqueKeys, nil } @@ -687,21 +684,34 @@ func (this *Inspector) getSharedUniqueKeys(originalUniqueKeys, ghostUniqueKeys [ } // getSharedColumns returns the intersection of two lists of columns in same order as the first list -func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.ColumnList, columnRenameMap map[string]string) (*sql.ColumnList, *sql.ColumnList) { +func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.ColumnList, originalVirtualColumns, ghostVirtualColumns *sql.ColumnList, columnRenameMap map[string]string) (*sql.ColumnList, *sql.ColumnList) { sharedColumnNames := []string{} for _, originalColumn := range originalColumns.Names() { isSharedColumn := false for _, ghostColumn := range ghostColumns.Names() { if strings.EqualFold(originalColumn, ghostColumn) { isSharedColumn = true + break } if strings.EqualFold(columnRenameMap[originalColumn], ghostColumn) { isSharedColumn = true + break } } for droppedColumn := range this.migrationContext.DroppedColumnsMap { if strings.EqualFold(originalColumn, droppedColumn) { isSharedColumn = false + break + } + } + for _, virtualColumn := range originalVirtualColumns.Names() { + if strings.EqualFold(originalColumn, virtualColumn) { + isSharedColumn = false + } + } + for _, virtualColumn := range ghostVirtualColumns.Names() { + if strings.EqualFold(originalColumn, virtualColumn) { + isSharedColumn = false } } if isSharedColumn { @@ -744,14 +754,20 @@ func (this *Inspector) readChangelogState(hint string) (string, error) { } func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { - log.Infof("Recursively searching for replication master") + this.migrationContext.Log.Infof("Recursively searching for replication master") visitedKeys := mysql.NewInstanceKeyMap() return mysql.GetMasterConnectionConfigSafe(this.connectionConfig, visitedKeys, this.migrationContext.AllowedMasterMaster) } func (this *Inspector) getReplicationLag() (replicationLag time.Duration, err error) { - replicationLag, err = mysql.GetReplicationLag( - this.migrationContext.InspectorConnectionConfig, + replicationLag, err = mysql.GetReplicationLagFromSlaveStatus( + this.informationSchemaDb, ) return replicationLag, err } + +func (this *Inspector) Teardown() { + this.db.Close() + this.informationSchemaDb.Close() + return +} diff --git a/go/logic/migrator.go b/go/logic/migrator.go index b320161..291a490 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -11,6 +11,7 @@ import ( "math" "os" "strings" + "sync" "sync/atomic" "time" @@ -18,8 +19,6 @@ import ( "github.com/github/gh-ost/go/binlog" "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - - "github.com/outbrain/golib/log" ) type ChangelogState string @@ -62,7 +61,7 @@ const ( // Migrator is the main schema migration flow manager. type Migrator struct { - parser *sql.Parser + parser *sql.AlterTableParser inspector *Inspector applier *Applier eventsStreamer *EventsStreamer @@ -78,17 +77,19 @@ type Migrator struct { rowCopyCompleteFlag int64 // copyRowsQueue should not be buffered; if buffered some non-damaging but - // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete + // excessive work happens at the end of the iteration as new copy-jobs arrive before realizing the copy is complete copyRowsQueue chan tableWriteFunc applyEventsQueue chan *applyEventStruct handledChangelogStates map[string]bool + + finishedMigrating int64 } -func NewMigrator() *Migrator { +func NewMigrator(context *base.MigrationContext) *Migrator { migrator := &Migrator{ - migrationContext: base.GetMigrationContext(), - parser: sql.NewParser(), + migrationContext: context, + parser: sql.NewAlterTableParser(), ghostTableMigrated: make(chan bool), firstThrottlingCollected: make(chan bool, 3), rowCopyComplete: make(chan error), @@ -97,13 +98,14 @@ func NewMigrator() *Migrator { copyRowsQueue: make(chan tableWriteFunc), applyEventsQueue: make(chan *applyEventStruct, base.MaxEventsBatchSize), handledChangelogStates: make(map[string]bool), + finishedMigrating: 0, } return migrator } // initiateHooksExecutor func (this *Migrator) initiateHooksExecutor() (err error) { - this.hooksExecutor = NewHooksExecutor() + this.hooksExecutor = NewHooksExecutor(this.migrationContext) if err := this.hooksExecutor.initHooks(); err != nil { return err } @@ -146,6 +148,34 @@ func (this *Migrator) retryOperation(operation func() error, notFatalHint ...boo return err } +// `retryOperationWithExponentialBackoff` attempts running given function, waiting 2^(n-1) +// seconds between each attempt, where `n` is the running number of attempts. Exits +// as soon as the function returns with non-error, or as soon as `MaxRetries` +// attempts are reached. Wait intervals between attempts obey a maximum of +// `ExponentialBackoffMaxInterval`. +func (this *Migrator) retryOperationWithExponentialBackoff(operation func() error, notFatalHint ...bool) (err error) { + var interval int64 + maxRetries := int(this.migrationContext.MaxRetries()) + maxInterval := this.migrationContext.ExponentialBackoffMaxInterval + for i := 0; i < maxRetries; i++ { + newInterval := int64(math.Exp2(float64(i - 1))) + if newInterval <= maxInterval { + interval = newInterval + } + if i != 0 { + time.Sleep(time.Duration(interval) * time.Second) + } + err = operation() + if err == nil { + return nil + } + } + if len(notFatalHint) == 0 { + this.migrationContext.PanicAbort <- err + } + return err +} + // executeAndThrottleOnError executes a given function. If it errors, it // throttles. func (this *Migrator) executeAndThrottleOnError(operation func() error) (err error) { @@ -179,13 +209,13 @@ func (this *Migrator) canStopStreaming() bool { // onChangelogStateEvent is called when a binlog event operation on the changelog table is intercepted. func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { - // Hey, I created the changlog table, I know the type of columns it has! + // Hey, I created the changelog table, I know the type of columns it has! if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" { return nil } changelogStateString := dmlEvent.NewColumnValues.StringColumn(3) changelogState := ReadChangelogState(changelogStateString) - log.Infof("Intercepted changelog state %s", changelogState) + this.migrationContext.Log.Infof("Intercepted changelog state %s", changelogState) switch changelogState { case GhostTableMigrated: { @@ -211,26 +241,30 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return fmt.Errorf("Unknown changelog state: %+v", changelogState) } } - log.Infof("Handled changelog state %s", changelogState) + this.migrationContext.Log.Infof("Handled changelog state %s", changelogState) return nil } // listenOnPanicAbort aborts on abort request func (this *Migrator) listenOnPanicAbort() { err := <-this.migrationContext.PanicAbort - log.Fatale(err) + this.migrationContext.Log.Fatale(err) } // validateStatement validates the `alter` statement meets criteria. // At this time this means: // - column renames are approved +// - no table rename allowed func (this *Migrator) validateStatement() (err error) { + if this.parser.IsRenameTable() { + return fmt.Errorf("ALTER statement seems to RENAME the table. This is not supported, and you should run your RENAME outside gh-ost.") + } if this.parser.HasNonTrivialRenames() && !this.migrationContext.SkipRenamedColumns { this.migrationContext.ColumnRenameMap = this.parser.GetNonTrivialRenames() if !this.migrationContext.ApproveRenamedColumns { return fmt.Errorf("gh-ost believes the ALTER statement renames columns, as follows: %v; as precaution, you are asked to confirm gh-ost is correct, and provide with `--approve-renamed-columns`, and we're all happy. Or you can skip renamed columns via `--skip-renamed-columns`, in which case column data may be lost", this.parser.GetNonTrivialRenames()) } - log.Infof("Alter statement has column(s) renamed. gh-ost finds the following renames: %v; --approve-renamed-columns is given and so migration proceeds.", this.parser.GetNonTrivialRenames()) + this.migrationContext.Log.Infof("Alter statement has column(s) renamed. gh-ost finds the following renames: %v; --approve-renamed-columns is given and so migration proceeds.", this.parser.GetNonTrivialRenames()) } this.migrationContext.DroppedColumnsMap = this.parser.DroppedColumnsMap() return nil @@ -242,7 +276,7 @@ func (this *Migrator) countTableRows() (err error) { return nil } if this.migrationContext.Noop { - log.Debugf("Noop operation; not really counting table rows") + this.migrationContext.Log.Debugf("Noop operation; not really counting table rows") return nil } @@ -257,7 +291,7 @@ func (this *Migrator) countTableRows() (err error) { } if this.migrationContext.ConcurrentCountTableRows { - log.Infof("As instructed, counting rows in the background; meanwhile I will use an estimated count, and will update it later on") + this.migrationContext.Log.Infof("As instructed, counting rows in the background; meanwhile I will use an estimated count, and will update it later on") go countRowsFunc() // and we ignore errors, because this turns to be a background job return nil @@ -269,9 +303,9 @@ func (this *Migrator) createFlagFiles() (err error) { if this.migrationContext.PostponeCutOverFlagFile != "" { if !base.FileExists(this.migrationContext.PostponeCutOverFlagFile) { if err := base.TouchFile(this.migrationContext.PostponeCutOverFlagFile); err != nil { - return log.Errorf("--postpone-cut-over-flag-file indicated by gh-ost is unable to create said file: %s", err.Error()) + return this.migrationContext.Log.Errorf("--postpone-cut-over-flag-file indicated by gh-ost is unable to create said file: %s", err.Error()) } - log.Infof("Created postpone-cut-over-flag-file: %s", this.migrationContext.PostponeCutOverFlagFile) + this.migrationContext.Log.Infof("Created postpone-cut-over-flag-file: %s", this.migrationContext.PostponeCutOverFlagFile) } } return nil @@ -279,7 +313,7 @@ func (this *Migrator) createFlagFiles() (err error) { // Migrate executes the complete migration logic. This is *the* major gh-ost function. func (this *Migrator) Migrate() (err error) { - log.Infof("Migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + this.migrationContext.Log.Infof("Migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) this.migrationContext.StartTime = time.Now() if this.migrationContext.Hostname, err = os.Hostname(); err != nil { return err @@ -299,6 +333,11 @@ func (this *Migrator) Migrate() (err error) { if err := this.validateStatement(); err != nil { return err } + + // After this point, we'll need to teardown anything that's been started + // so we don't leave things hanging around + defer this.teardown() + if err := this.initiateInspector(); err != nil { return err } @@ -313,9 +352,9 @@ func (this *Migrator) Migrate() (err error) { } initialLag, _ := this.inspector.getReplicationLag() - log.Infof("Waiting for ghost table to be migrated. Current lag is %+v", initialLag) + this.migrationContext.Log.Infof("Waiting for ghost table to be migrated. Current lag is %+v", initialLag) <-this.ghostTableMigrated - log.Debugf("ghost table migrated") + this.migrationContext.Log.Debugf("ghost table migrated") // Yay! We now know the Ghost and Changelog tables are good to examine! // When running on replica, this means the replica has those tables. When running // on master this is always true, of course, and yet it also implies this knowledge @@ -353,9 +392,9 @@ func (this *Migrator) Migrate() (err error) { this.migrationContext.MarkRowCopyStartTime() go this.initiateStatus() - log.Debugf("Operating until row copy is complete") + this.migrationContext.Log.Debugf("Operating until row copy is complete") this.consumeRowCopyComplete() - log.Infof("Row copy complete") + this.migrationContext.Log.Infof("Row copy complete") if err := this.hooksExecutor.onRowCopyComplete(); err != nil { return err } @@ -364,7 +403,13 @@ func (this *Migrator) Migrate() (err error) { if err := this.hooksExecutor.onBeforeCutOver(); err != nil { return err } - if err := this.retryOperation(this.cutOver); err != nil { + var retrier func(func() error, ...bool) error + if this.migrationContext.CutOverExponentialBackoff { + retrier = this.retryOperationWithExponentialBackoff + } else { + retrier = this.retryOperation + } + if err := retrier(this.cutOver); err != nil { return err } atomic.StoreInt64(&this.migrationContext.CutOverCompleteFlag, 1) @@ -375,7 +420,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.hooksExecutor.onSuccess(); err != nil { return err } - log.Infof("Done migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + this.migrationContext.Log.Infof("Done migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) return nil } @@ -387,7 +432,7 @@ func (this *Migrator) ExecOnFailureHook() (err error) { func (this *Migrator) handleCutOverResult(cutOverError error) (err error) { if this.migrationContext.TestOnReplica { - // We're merly testing, we don't want to keep this state. Rollback the renames as possible + // We're merely testing, we don't want to keep this state. Rollback the renames as possible this.applier.RenameTablesRollback() } if cutOverError == nil { @@ -401,14 +446,14 @@ func (this *Migrator) handleCutOverResult(cutOverError error) (err error) { // and swap the tables. // The difference is that we will later swap the tables back. if err := this.hooksExecutor.onStartReplication(); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } if this.migrationContext.TestOnReplicaSkipReplicaStop { - log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not starting replication.") + this.migrationContext.Log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not starting replication.") } else { - log.Debugf("testing on replica. Starting replication IO thread after cut-over failure") + this.migrationContext.Log.Debugf("testing on replica. Starting replication IO thread after cut-over failure") if err := this.retryOperation(this.applier.StartReplication); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } } } @@ -419,16 +464,16 @@ func (this *Migrator) handleCutOverResult(cutOverError error) (err error) { // type (on replica? atomic? safe?) func (this *Migrator) cutOver() (err error) { if this.migrationContext.Noop { - log.Debugf("Noop operation; not really swapping tables") + this.migrationContext.Log.Debugf("Noop operation; not really swapping tables") return nil } this.migrationContext.MarkPointOfInterest() this.throttler.throttle(func() { - log.Debugf("throttling before swapping tables") + this.migrationContext.Log.Debugf("throttling before swapping tables") }) this.migrationContext.MarkPointOfInterest() - log.Debugf("checking for cut-over postpone") + this.migrationContext.Log.Debugf("checking for cut-over postpone") this.sleepWhileTrue( func() (bool, error) { if this.migrationContext.PostponeCutOverFlagFile == "" { @@ -453,7 +498,7 @@ func (this *Migrator) cutOver() (err error) { ) atomic.StoreInt64(&this.migrationContext.IsPostponingCutOver, 0) this.migrationContext.MarkPointOfInterest() - log.Debugf("checking for cut-over postpone: complete") + this.migrationContext.Log.Debugf("checking for cut-over postpone: complete") if this.migrationContext.TestOnReplica { // With `--test-on-replica` we stop replication thread, and then proceed to use @@ -464,9 +509,9 @@ func (this *Migrator) cutOver() (err error) { return err } if this.migrationContext.TestOnReplicaSkipReplicaStop { - log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not stopping replication.") + this.migrationContext.Log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not stopping replication.") } else { - log.Debugf("testing on replica. Stopping replication IO thread") + this.migrationContext.Log.Debugf("testing on replica. Stopping replication IO thread") if err := this.retryOperation(this.applier.StopReplication); err != nil { return err } @@ -484,7 +529,7 @@ func (this *Migrator) cutOver() (err error) { this.handleCutOverResult(err) return err } - return log.Fatalf("Unknown cut-over type: %d; should never get here!", this.migrationContext.CutOverType) + return this.migrationContext.Log.Fatalf("Unknown cut-over type: %d; should never get here!", this.migrationContext.CutOverType) } // Inject the "AllEventsUpToLockProcessed" state hint, wait for it to appear in the binary logs, @@ -496,32 +541,32 @@ func (this *Migrator) waitForEventsUpToLock() (err error) { waitForEventsUpToLockStartTime := time.Now() allEventsUpToLockProcessedChallenge := fmt.Sprintf("%s:%d", string(AllEventsUpToLockProcessed), waitForEventsUpToLockStartTime.UnixNano()) - log.Infof("Writing changelog state: %+v", allEventsUpToLockProcessedChallenge) + this.migrationContext.Log.Infof("Writing changelog state: %+v", allEventsUpToLockProcessedChallenge) if _, err := this.applier.WriteChangelogState(allEventsUpToLockProcessedChallenge); err != nil { return err } - log.Infof("Waiting for events up to lock") + this.migrationContext.Log.Infof("Waiting for events up to lock") atomic.StoreInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 1) for found := false; !found; { select { case <-timeout.C: { - return log.Errorf("Timeout while waiting for events up to lock") + return this.migrationContext.Log.Errorf("Timeout while waiting for events up to lock") } case state := <-this.allEventsUpToLockProcessed: { if state == allEventsUpToLockProcessedChallenge { - log.Infof("Waiting for events up to lock: got %s", state) + this.migrationContext.Log.Infof("Waiting for events up to lock: got %s", state) found = true } else { - log.Infof("Waiting for events up to lock: skipping %s", state) + this.migrationContext.Log.Infof("Waiting for events up to lock: skipping %s", state) } } } } waitForEventsUpToLockDuration := time.Since(waitForEventsUpToLockStartTime) - log.Infof("Done waiting for events up to lock; duration=%+v", waitForEventsUpToLockDuration) + this.migrationContext.Log.Infof("Done waiting for events up to lock; duration=%+v", waitForEventsUpToLockDuration) this.printStatus(ForcePrintStatusAndHintRule) return nil @@ -552,7 +597,7 @@ func (this *Migrator) cutOverTwoStep() (err error) { lockAndRenameDuration := this.migrationContext.RenameTablesEndTime.Sub(this.migrationContext.LockTablesStartTime) renameDuration := this.migrationContext.RenameTablesEndTime.Sub(this.migrationContext.RenameTablesStartTime) - log.Debugf("Lock & rename duration: %s (rename only: %s). During this time, queries on %s were locked or failing", lockAndRenameDuration, renameDuration, sql.EscapeName(this.migrationContext.OriginalTableName)) + this.migrationContext.Log.Debugf("Lock & rename duration: %s (rename only: %s). During this time, queries on %s were locked or failing", lockAndRenameDuration, renameDuration, sql.EscapeName(this.migrationContext.OriginalTableName)) return nil } @@ -562,9 +607,12 @@ func (this *Migrator) atomicCutOver() (err error) { defer atomic.StoreInt64(&this.migrationContext.InCutOverCriticalSectionFlag, 0) okToUnlockTable := make(chan bool, 4) + var dropCutOverSentryTableOnce sync.Once defer func() { okToUnlockTable <- true - this.applier.DropAtomicCutOverSentryTableIfExists() + dropCutOverSentryTableOnce.Do(func() { + this.applier.DropAtomicCutOverSentryTableIfExists() + }) }() atomic.StoreInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 0) @@ -573,19 +621,19 @@ func (this *Migrator) atomicCutOver() (err error) { tableLocked := make(chan error, 2) tableUnlocked := make(chan error, 2) go func() { - if err := this.applier.AtomicCutOverMagicLock(lockOriginalSessionIdChan, tableLocked, okToUnlockTable, tableUnlocked); err != nil { - log.Errore(err) + if err := this.applier.AtomicCutOverMagicLock(lockOriginalSessionIdChan, tableLocked, okToUnlockTable, tableUnlocked, &dropCutOverSentryTableOnce); err != nil { + this.migrationContext.Log.Errore(err) } }() if err := <-tableLocked; err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } lockOriginalSessionId := <-lockOriginalSessionIdChan - log.Infof("Session locking original & magic tables is %+v", lockOriginalSessionId) + this.migrationContext.Log.Infof("Session locking original & magic tables is %+v", lockOriginalSessionId) // At this point we know the original table is locked. // We know any newly incoming DML on original table is blocked. if err := this.waitForEventsUpToLock(); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } // Step 2 @@ -603,7 +651,7 @@ func (this *Migrator) atomicCutOver() (err error) { } }() renameSessionId := <-renameSessionIdChan - log.Infof("Session renaming tables is %+v", renameSessionId) + this.migrationContext.Log.Infof("Session renaming tables is %+v", renameSessionId) waitForRename := func() error { if atomic.LoadInt64(&tableRenameKnownToHaveFailed) == 1 { @@ -620,13 +668,13 @@ func (this *Migrator) atomicCutOver() (err error) { return err } if atomic.LoadInt64(&tableRenameKnownToHaveFailed) == 0 { - log.Infof("Found atomic RENAME to be blocking, as expected. Double checking the lock is still in place (though I don't strictly have to)") + this.migrationContext.Log.Infof("Found atomic RENAME to be blocking, as expected. Double checking the lock is still in place (though I don't strictly have to)") } if err := this.applier.ExpectUsedLock(lockOriginalSessionId); err != nil { // Abort operation. Just make sure to drop the magic table. - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } - log.Infof("Connection holding lock on original table still exists") + this.migrationContext.Log.Infof("Connection holding lock on original table still exists") // Now that we've found the RENAME blocking, AND the locking connection still alive, // we know it is safe to proceed to release the lock @@ -635,16 +683,16 @@ func (this *Migrator) atomicCutOver() (err error) { // BAM! magic table dropped, original table lock is released // -> RENAME released -> queries on original are unblocked. if err := <-tableUnlocked; err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } if err := <-tablesRenamed; err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } this.migrationContext.RenameTablesEndTime = time.Now() // ooh nice! We're actually truly and thankfully done lockAndRenameDuration := this.migrationContext.RenameTablesEndTime.Sub(this.migrationContext.LockTablesStartTime) - log.Infof("Lock & rename duration: %s. During this time, queries on %s were blocked", lockAndRenameDuration, sql.EscapeName(this.migrationContext.OriginalTableName)) + this.migrationContext.Log.Infof("Lock & rename duration: %s. During this time, queries on %s were blocked", lockAndRenameDuration, sql.EscapeName(this.migrationContext.OriginalTableName)) return nil } @@ -653,7 +701,7 @@ func (this *Migrator) initiateServer() (err error) { var f printStatusFunc = func(rule PrintStatusRule, writer io.Writer) { this.printStatus(rule, writer) } - this.server = NewServer(this.hooksExecutor, f) + this.server = NewServer(this.migrationContext, this.hooksExecutor, f) if err := this.server.BindSocketFile(); err != nil { return err } @@ -673,7 +721,7 @@ func (this *Migrator) initiateServer() (err error) { // - heartbeat // When `--allow-on-master` is supplied, the inspector is actually the master. func (this *Migrator) initiateInspector() (err error) { - this.inspector = NewInspector() + this.inspector = NewInspector(this.migrationContext) if err := this.inspector.InitDBConnections(); err != nil { return err } @@ -690,10 +738,10 @@ func (this *Migrator) initiateInspector() (err error) { if this.migrationContext.ApplierConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { return err } - log.Infof("Master found to be %+v", *this.migrationContext.ApplierConnectionConfig.ImpliedKey) + this.migrationContext.Log.Infof("Master found to be %+v", *this.migrationContext.ApplierConnectionConfig.ImpliedKey) } else { // Forced master host. - key, err := mysql.ParseRawInstanceKeyLoose(this.migrationContext.AssumeMasterHostname) + key, err := mysql.ParseInstanceKey(this.migrationContext.AssumeMasterHostname) if err != nil { return err } @@ -704,14 +752,14 @@ func (this *Migrator) initiateInspector() (err error) { if this.migrationContext.CliMasterPassword != "" { this.migrationContext.ApplierConnectionConfig.Password = this.migrationContext.CliMasterPassword } - log.Infof("Master forced to be %+v", *this.migrationContext.ApplierConnectionConfig.ImpliedKey) + this.migrationContext.Log.Infof("Master forced to be %+v", *this.migrationContext.ApplierConnectionConfig.ImpliedKey) } // validate configs if this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica { if this.migrationContext.InspectorIsAlsoApplier() { return fmt.Errorf("Instructed to --test-on-replica or --migrate-on-replica, but the server we connect to doesn't seem to be a replica") } - log.Infof("--test-on-replica or --migrate-on-replica given. Will not execute on master %+v but rather on replica %+v itself", + this.migrationContext.Log.Infof("--test-on-replica or --migrate-on-replica given. Will not execute on master %+v but rather on replica %+v itself", *this.migrationContext.ApplierConnectionConfig.ImpliedKey, *this.migrationContext.InspectorConnectionConfig.ImpliedKey, ) this.migrationContext.ApplierConnectionConfig = this.migrationContext.InspectorConnectionConfig.Duplicate() @@ -733,6 +781,9 @@ func (this *Migrator) initiateStatus() error { this.printStatus(ForcePrintStatusAndHintRule) statusTick := time.Tick(1 * time.Second) for range statusTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return nil + } go this.printStatus(HeuristicPrintStatusRule) } @@ -742,7 +793,7 @@ func (this *Migrator) initiateStatus() error { // printMigrationStatusHint prints a detailed configuration dump, that is useful // to keep in mind; such as the name of migrated table, throttle params etc. // This gets printed at beginning and end of migration, every 10 minutes throughout -// migration, and as reponse to the "status" interactive command. +// migration, and as response to the "status" interactive command. func (this *Migrator) printMigrationStatusHint(writers ...io.Writer) { w := io.MultiWriter(writers...) fmt.Fprintln(w, fmt.Sprintf("# Migrating %s.%s; Ghost table is %s.%s", @@ -820,7 +871,7 @@ func (this *Migrator) printMigrationStatusHint(writers ...io.Writer) { } } -// printStatus prints the prgoress status, and optionally additionally detailed +// printStatus prints the progress status, and optionally additionally detailed // dump of configuration. // `rule` indicates the type of output expected. // By default the status is written to standard output, but other writers can @@ -846,6 +897,8 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { } else { progressPct = 100.0 * float64(totalRowsCopied) / float64(rowsEstimate) } + // we take the opportunity to update migration context with progressPct + this.migrationContext.SetProgressPct(progressPct) // Before status, let's see if we should print a nice reminder for what exactly we're doing here. shouldPrintMigrationStatusHint := (elapsedSeconds%600 == 0) if rule == ForcePrintStatusAndHintRule { @@ -862,7 +915,7 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { eta := "N/A" if progressPct >= 100.0 { eta = "due" - } else if progressPct >= 1.0 { + } else if progressPct >= 0.1 { elapsedRowCopySeconds := this.migrationContext.ElapsedRowCopyTime().Seconds() totalExpectedSeconds := elapsedRowCopySeconds * float64(rowsEstimate) / float64(totalRowsCopied) etaSeconds = totalExpectedSeconds - elapsedRowCopySeconds @@ -909,12 +962,13 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { currentBinlogCoordinates := *this.eventsStreamer.GetCurrentBinlogCoordinates() - status := fmt.Sprintf("Copy: %d/%d %.1f%%; Applied: %d; Backlog: %d/%d; Time: %+v(total), %+v(copy); streamer: %+v; State: %s; ETA: %s", + status := fmt.Sprintf("Copy: %d/%d %.1f%%; Applied: %d; Backlog: %d/%d; Time: %+v(total), %+v(copy); streamer: %+v; Lag: %.2fs, State: %s; ETA: %s", totalRowsCopied, rowsEstimate, progressPct, atomic.LoadInt64(&this.migrationContext.TotalDMLEventsApplied), len(this.applyEventsQueue), cap(this.applyEventsQueue), base.PrettifyDurationOutput(elapsedTime), base.PrettifyDurationOutput(this.migrationContext.ElapsedRowCopyTime()), currentBinlogCoordinates, + this.migrationContext.GetCurrentLagDuration().Seconds(), state, eta, ) @@ -932,7 +986,7 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { // initiateStreaming begins streaming of binary log events and registers listeners for such events func (this *Migrator) initiateStreaming() error { - this.eventsStreamer = NewEventsStreamer() + this.eventsStreamer = NewEventsStreamer(this.migrationContext) if err := this.eventsStreamer.InitDBConnections(); err != nil { return err } @@ -946,17 +1000,20 @@ func (this *Migrator) initiateStreaming() error { ) go func() { - log.Debugf("Beginning streaming") + this.migrationContext.Log.Debugf("Beginning streaming") err := this.eventsStreamer.StreamEvents(this.canStopStreaming) if err != nil { this.migrationContext.PanicAbort <- err } - log.Debugf("Done streaming") + this.migrationContext.Log.Debugf("Done streaming") }() go func() { ticker := time.Tick(1 * time.Second) for range ticker { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } this.migrationContext.SetRecentBinlogCoordinates(*this.eventsStreamer.GetCurrentBinlogCoordinates()) } }() @@ -980,21 +1037,21 @@ func (this *Migrator) addDMLEventsListener() error { // initiateThrottler kicks in the throttling collection and the throttling checks. func (this *Migrator) initiateThrottler() error { - this.throttler = NewThrottler(this.applier, this.inspector) + this.throttler = NewThrottler(this.migrationContext, this.applier, this.inspector) go this.throttler.initiateThrottlerCollection(this.firstThrottlingCollected) - log.Infof("Waiting for first throttle metrics to be collected") + this.migrationContext.Log.Infof("Waiting for first throttle metrics to be collected") <-this.firstThrottlingCollected // replication lag <-this.firstThrottlingCollected // HTTP status <-this.firstThrottlingCollected // other, general metrics - log.Infof("First throttle metrics collected") + this.migrationContext.Log.Infof("First throttle metrics collected") go this.throttler.initiateThrottlerChecks() return nil } func (this *Migrator) initiateApplier() error { - this.applier = NewApplier() + this.applier = NewApplier(this.migrationContext) if err := this.applier.InitDBConnections(); err != nil { return err } @@ -1002,16 +1059,16 @@ func (this *Migrator) initiateApplier() error { return err } if err := this.applier.CreateChangelogTable(); err != nil { - log.Errorf("Unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") + this.migrationContext.Log.Errorf("Unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") return err } if err := this.applier.CreateGhostTable(); err != nil { - log.Errorf("Unable to create ghost table, see further error details. Perhaps a previous migration failed without dropping the table? Bailing out") + this.migrationContext.Log.Errorf("Unable to create ghost table, see further error details. Perhaps a previous migration failed without dropping the table? Bailing out") return err } if err := this.applier.AlterGhost(); err != nil { - log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out") + this.migrationContext.Log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out") return err } @@ -1025,34 +1082,43 @@ func (this *Migrator) initiateApplier() error { func (this *Migrator) iterateChunks() error { terminateRowIteration := func(err error) error { this.rowCopyComplete <- err - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } if this.migrationContext.Noop { - log.Debugf("Noop operation; not really copying data") + this.migrationContext.Log.Debugf("Noop operation; not really copying data") return terminateRowIteration(nil) } if this.migrationContext.MigrationRangeMinValues == nil { - log.Debugf("No rows found in table. Rowcopy will be implicitly empty") + this.migrationContext.Log.Debugf("No rows found in table. Rowcopy will be implicitly empty") return terminateRowIteration(nil) } + + var hasNoFurtherRangeFlag int64 // Iterate per chunk: for { - if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 { + if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { // Done // There's another such check down the line return nil } copyRowsFunc := func() error { - if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 { + if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { // Done. // There's another such check down the line return nil } - hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues() - if err != nil { + + // When hasFurtherRange is false, original table might be write locked and CalculateNextIterationRangeEndValues would hangs forever + + hasFurtherRange := false + if err := this.retryOperation(func() (e error) { + hasFurtherRange, e = this.applier.CalculateNextIterationRangeEndValues() + return e + }); err != nil { return terminateRowIteration(err) } if !hasFurtherRange { + atomic.StoreInt64(&hasNoFurtherRangeFlag, 1) return terminateRowIteration(nil) } // Copy task: @@ -1070,7 +1136,7 @@ func (this *Migrator) iterateChunks() error { } _, rowsAffected, _, err := this.applier.ApplyIterationInsertQuery() if err != nil { - return terminateRowIteration(err) + return err // wrapping call will retry } atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected) atomic.AddInt64(&this.migrationContext.Iteration, 1) @@ -1091,7 +1157,7 @@ func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { handleNonDMLEventStruct := func(eventStruct *applyEventStruct) error { if eventStruct.writeFunc != nil { if err := this.retryOperation(*eventStruct.writeFunc); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } } return nil @@ -1125,13 +1191,13 @@ func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { return this.applier.ApplyDMLEventQueries(dmlEvents) } if err := this.retryOperation(applyEventFunc); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } if nonDmlStructToApply != nil { // We pulled DML events from the queue, and then we hit a non-DML event. Wait! // We need to handle it! if err := handleNonDMLEventStruct(nonDmlStructToApply); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } } } @@ -1143,10 +1209,14 @@ func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { // Both event backlog and rowcopy events are polled; the backlog events have precedence. func (this *Migrator) executeWriteFuncs() error { if this.migrationContext.Noop { - log.Debugf("Noop operation; not really executing write funcs") + this.migrationContext.Log.Debugf("Noop operation; not really executing write funcs") return nil } for { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return nil + } + this.throttler.throttle(nil) // We give higher priority to event processing, then secondary priority to @@ -1166,7 +1236,7 @@ func (this *Migrator) executeWriteFuncs() error { copyRowsStartTime := time.Now() // Retries are handled within the copyRowsFunc if err := copyRowsFunc(); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } if niceRatio := this.migrationContext.GetNiceRatio(); niceRatio > 0 { copyRowsDuration := time.Since(copyRowsStartTime) @@ -1179,7 +1249,7 @@ func (this *Migrator) executeWriteFuncs() error { { // Hmmmmm... nothing in the queue; no events, but also no row copy. // This is possible upon load. Let's just sleep it over. - log.Debugf("Getting nothing in the write queue. Sleeping...") + this.migrationContext.Log.Debugf("Getting nothing in the write queue. Sleeping...") time.Sleep(time.Second) } } @@ -1195,14 +1265,14 @@ func (this *Migrator) finalCleanup() error { if this.migrationContext.Noop { if createTableStatement, err := this.inspector.showCreateTable(this.migrationContext.GetGhostTableName()); err == nil { - log.Infof("New table structure follows") + this.migrationContext.Log.Infof("New table structure follows") fmt.Println(createTableStatement) } else { - log.Errore(err) + this.migrationContext.Log.Errore(err) } } if err := this.eventsStreamer.Close(); err != nil { - log.Errore(err) + this.migrationContext.Log.Errore(err) } if err := this.retryOperation(this.applier.DropChangelogTable); err != nil { @@ -1214,8 +1284,8 @@ func (this *Migrator) finalCleanup() error { } } else { if !this.migrationContext.Noop { - log.Infof("Am not dropping old table because I want this operation to be as live as possible. If you insist I should do it, please add `--ok-to-drop-table` next time. But I prefer you do not. To drop the old table, issue:") - log.Infof("-- drop table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetOldTableName())) + this.migrationContext.Log.Infof("Am not dropping old table because I want this operation to be as live as possible. If you insist I should do it, please add `--ok-to-drop-table` next time. But I prefer you do not. To drop the old table, issue:") + this.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetOldTableName())) } } if this.migrationContext.Noop { @@ -1226,3 +1296,27 @@ func (this *Migrator) finalCleanup() error { return nil } + +func (this *Migrator) teardown() { + atomic.StoreInt64(&this.finishedMigrating, 1) + + if this.inspector != nil { + this.migrationContext.Log.Infof("Tearing down inspector") + this.inspector.Teardown() + } + + if this.applier != nil { + this.migrationContext.Log.Infof("Tearing down applier") + this.applier.Teardown() + } + + if this.eventsStreamer != nil { + this.migrationContext.Log.Infof("Tearing down streamer") + this.eventsStreamer.Teardown() + } + + if this.throttler != nil { + this.migrationContext.Log.Infof("Tearing down throttler") + this.throttler.Teardown() + } +} diff --git a/go/logic/server.go b/go/logic/server.go index 7034cfd..1606884 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -16,7 +16,6 @@ import ( "sync/atomic" "github.com/github/gh-ost/go/base" - "github.com/outbrain/golib/log" ) type printStatusFunc func(PrintStatusRule, io.Writer) @@ -30,9 +29,9 @@ type Server struct { printStatus printStatusFunc } -func NewServer(hooksExecutor *HooksExecutor, printStatus printStatusFunc) *Server { +func NewServer(migrationContext *base.MigrationContext, hooksExecutor *HooksExecutor, printStatus printStatusFunc) *Server { return &Server{ - migrationContext: base.GetMigrationContext(), + migrationContext: migrationContext, hooksExecutor: hooksExecutor, printStatus: printStatus, } @@ -49,12 +48,12 @@ func (this *Server) BindSocketFile() (err error) { if err != nil { return err } - log.Infof("Listening on unix socket file: %s", this.migrationContext.ServeSocketFile) + this.migrationContext.Log.Infof("Listening on unix socket file: %s", this.migrationContext.ServeSocketFile) return nil } func (this *Server) RemoveSocketFile() (err error) { - log.Infof("Removing socket file: %s", this.migrationContext.ServeSocketFile) + this.migrationContext.Log.Infof("Removing socket file: %s", this.migrationContext.ServeSocketFile) return os.Remove(this.migrationContext.ServeSocketFile) } @@ -66,7 +65,7 @@ func (this *Server) BindTCPPort() (err error) { if err != nil { return err } - log.Infof("Listening on tcp port: %d", this.migrationContext.ServeTCPPort) + this.migrationContext.Log.Infof("Listening on tcp port: %d", this.migrationContext.ServeTCPPort) return nil } @@ -76,7 +75,7 @@ func (this *Server) Serve() (err error) { for { conn, err := this.unixListener.Accept() if err != nil { - log.Errore(err) + this.migrationContext.Log.Errore(err) } go this.handleConnection(conn) } @@ -88,7 +87,7 @@ func (this *Server) Serve() (err error) { for { conn, err := this.tcpListener.Accept() if err != nil { - log.Errore(err) + this.migrationContext.Log.Errore(err) } go this.handleConnection(conn) } @@ -118,7 +117,7 @@ func (this *Server) onServerCommand(command string, writer *bufio.Writer) (err e } else { fmt.Fprintf(writer, "%s\n", err.Error()) } - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } // applyServerCommand parses and executes commands by user @@ -130,6 +129,9 @@ func (this *Server) applyServerCommand(command string, writer *bufio.Writer) (pr arg := "" if len(tokens) > 1 { arg = strings.TrimSpace(tokens[1]) + if unquoted, err := strconv.Unquote(arg); err == nil { + arg = unquoted + } } argIsQuestion := (arg == "?") throttleHint := "# Note: you may only throttle for as long as your binary logs are not purged\n" @@ -141,13 +143,13 @@ func (this *Server) applyServerCommand(command string, writer *bufio.Writer) (pr switch command { case "help": { - fmt.Fprintln(writer, `available commands: + fmt.Fprint(writer, `available commands: status # Print a detailed status message sup # Print a short status message coordinates # Print the currently inspected coordinates chunk-size= # Set a new chunk-size dml-batch-size= # Set a new dml-batch-size -nice-ratio= # Set a new nice-ratio, immediate sleep after each row-copy operation, float (examples: 0 is agrressive, 0.7 adds 70% runtime, 1.0 doubles runtime, 2.0 triples runtime, ...) +nice-ratio= # Set a new nice-ratio, immediate sleep after each row-copy operation, float (examples: 0 is aggressive, 0.7 adds 70% runtime, 1.0 doubles runtime, 2.0 triples runtime, ...) critical-load= # Set a new set of max-load thresholds max-lag-millis= # Set a new replication lag threshold replication-lag-query= # Set a new query that determines replication lag (no quotes) @@ -289,12 +291,22 @@ help # This message } case "throttle", "pause", "suspend": { + if arg != "" && arg != this.migrationContext.OriginalTableName { + // User explicitly provided table name. This is a courtesy protection mechanism + err := fmt.Errorf("User commanded 'throttle' on %s, but migrated table is %s; ignoring request.", arg, this.migrationContext.OriginalTableName) + return NoPrintStatusRule, err + } atomic.StoreInt64(&this.migrationContext.ThrottleCommandedByUser, 1) fmt.Fprintf(writer, throttleHint) return ForcePrintStatusAndHintRule, nil } case "no-throttle", "unthrottle", "resume", "continue": { + if arg != "" && arg != this.migrationContext.OriginalTableName { + // User explicitly provided table name. This is a courtesy protection mechanism + err := fmt.Errorf("User commanded 'no-throttle' on %s, but migrated table is %s; ignoring request.", arg, this.migrationContext.OriginalTableName) + return NoPrintStatusRule, err + } atomic.StoreInt64(&this.migrationContext.ThrottleCommandedByUser, 0) return ForcePrintStatusAndHintRule, nil } @@ -305,8 +317,8 @@ help # This message return NoPrintStatusRule, err } if arg != "" && arg != this.migrationContext.OriginalTableName { - // User exlpicitly provided table name. This is a courtesy protection mechanism - err := fmt.Errorf("User commanded 'unpostpone' on %s, but migrated table is %s; ingoring request.", arg, this.migrationContext.OriginalTableName) + // User explicitly provided table name. This is a courtesy protection mechanism + err := fmt.Errorf("User commanded 'unpostpone' on %s, but migrated table is %s; ignoring request.", arg, this.migrationContext.OriginalTableName) return NoPrintStatusRule, err } if atomic.LoadInt64(&this.migrationContext.IsPostponingCutOver) > 0 { @@ -319,7 +331,16 @@ help # This message } case "panic": { - err := fmt.Errorf("User commanded 'panic'. I will now panic, without cleanup. PANIC!") + if arg == "" && this.migrationContext.ForceNamedPanicCommand { + err := fmt.Errorf("User commanded 'panic' without specifying table name, but --force-named-panic is set") + return NoPrintStatusRule, err + } + if arg != "" && arg != this.migrationContext.OriginalTableName { + // User explicitly provided table name. This is a courtesy protection mechanism + err := fmt.Errorf("User commanded 'panic' on %s, but migrated table is %s; ignoring request.", arg, this.migrationContext.OriginalTableName) + return NoPrintStatusRule, err + } + err := fmt.Errorf("User commanded 'panic'. The migration will be aborted without cleanup. Please drop the gh-ost tables before trying again.") this.migrationContext.PanicAbort <- err return NoPrintStatusRule, err } diff --git a/go/logic/streamer.go b/go/logic/streamer.go index 275af55..a07240c 100644 --- a/go/logic/streamer.go +++ b/go/logic/streamer.go @@ -16,7 +16,6 @@ import ( "github.com/github/gh-ost/go/binlog" "github.com/github/gh-ost/go/mysql" - "github.com/outbrain/golib/log" "github.com/outbrain/golib/sqlutils" ) @@ -46,10 +45,10 @@ type EventsStreamer struct { name string } -func NewEventsStreamer() *EventsStreamer { +func NewEventsStreamer(migrationContext *base.MigrationContext) *EventsStreamer { return &EventsStreamer{ - connectionConfig: base.GetMigrationContext().InspectorConnectionConfig, - migrationContext: base.GetMigrationContext(), + connectionConfig: migrationContext.InspectorConnectionConfig, + migrationContext: migrationContext, listeners: [](*BinlogEventListener){}, listenersMutex: &sync.Mutex{}, eventsChannel: make(chan *binlog.BinlogEntry, EventsChannelBufferSize), @@ -106,10 +105,10 @@ func (this *EventsStreamer) notifyListeners(binlogEvent *binlog.BinlogDMLEvent) func (this *EventsStreamer) InitDBConnections() (err error) { EventsStreamerUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(EventsStreamerUri); err != nil { + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, EventsStreamerUri); err != nil { return err } - if _, err := base.ValidateConnection(this.db, this.connectionConfig, this.name); err != nil { + if _, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name); err != nil { return err } if err := this.readCurrentBinlogCoordinates(); err != nil { @@ -124,7 +123,7 @@ func (this *EventsStreamer) InitDBConnections() (err error) { // initBinlogReader creates and connects the reader: we hook up to a MySQL server as a replica func (this *EventsStreamer) initBinlogReader(binlogCoordinates *mysql.BinlogCoordinates) error { - goMySQLReader, err := binlog.NewGoMySQLReader(this.migrationContext.InspectorConnectionConfig) + goMySQLReader, err := binlog.NewGoMySQLReader(this.migrationContext) if err != nil { return err } @@ -162,7 +161,7 @@ func (this *EventsStreamer) readCurrentBinlogCoordinates() error { if !foundMasterStatus { return fmt.Errorf("Got no results from SHOW MASTER STATUS. Bailing out") } - log.Debugf("Streamer binlog coordinates: %+v", *this.initialBinlogCoordinates) + this.migrationContext.Log.Debugf("Streamer binlog coordinates: %+v", *this.initialBinlogCoordinates) return nil } @@ -180,8 +179,15 @@ func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { var successiveFailures int64 var lastAppliedRowsEventHint mysql.BinlogCoordinates for { + if canStopStreaming() { + return nil + } if err := this.binlogReader.StreamEvents(canStopStreaming, this.eventsChannel); err != nil { - log.Infof("StreamEvents encountered unexpected error: %+v", err) + if canStopStreaming() { + return nil + } + + this.migrationContext.Log.Infof("StreamEvents encountered unexpected error: %+v", err) this.migrationContext.MarkPointOfInterest() time.Sleep(ReconnectStreamerSleepSeconds * time.Second) @@ -197,7 +203,7 @@ func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { // Reposition at same binlog file. lastAppliedRowsEventHint = this.binlogReader.LastAppliedRowsEventHint - log.Infof("Reconnecting... Will resume at %+v", lastAppliedRowsEventHint) + this.migrationContext.Log.Infof("Reconnecting... Will resume at %+v", lastAppliedRowsEventHint) if err := this.initBinlogReader(this.GetReconnectBinlogCoordinates()); err != nil { return err } @@ -208,6 +214,11 @@ func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { func (this *EventsStreamer) Close() (err error) { err = this.binlogReader.Close() - log.Infof("Closed streamer connection. err=%+v", err) + this.migrationContext.Log.Infof("Closed streamer connection. err=%+v", err) return err } + +func (this *EventsStreamer) Teardown() { + this.db.Close() + return +} diff --git a/go/logic/throttler.go b/go/logic/throttler.go index ae95b70..d234ea6 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -15,43 +15,45 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - "github.com/outbrain/golib/log" - "github.com/outbrain/golib/sqlutils" ) var ( - httpStatusMessages map[int]string = map[int]string{ + httpStatusMessages = map[int]string{ 200: "OK", 404: "Not found", 417: "Expectation failed", 429: "Too many requests", 500: "Internal server error", + -1: "Connection error", } // See https://github.com/github/freno/blob/master/doc/http.md - httpStatusFrenoMessages map[int]string = map[int]string{ + httpStatusFrenoMessages = map[int]string{ 200: "OK", 404: "freno: unknown metric", 417: "freno: access forbidden", 429: "freno: threshold exceeded", 500: "freno: internal error", + -1: "freno: connection error", } ) const frenoMagicHint = "freno" -// Throttler collects metrics related to throttling and makes informed decisison +// Throttler collects metrics related to throttling and makes informed decision // whether throttling should take place. type Throttler struct { - migrationContext *base.MigrationContext - applier *Applier - inspector *Inspector + migrationContext *base.MigrationContext + applier *Applier + inspector *Inspector + finishedMigrating int64 } -func NewThrottler(applier *Applier, inspector *Inspector) *Throttler { +func NewThrottler(migrationContext *base.MigrationContext, applier *Applier, inspector *Inspector) *Throttler { return &Throttler{ - migrationContext: base.GetMigrationContext(), - applier: applier, - inspector: inspector, + migrationContext: migrationContext, + applier: applier, + inspector: inspector, + finishedMigrating: 0, } } @@ -83,6 +85,7 @@ func (this *Throttler) shouldThrottle() (result bool, reason string, reasonHint if statusCode != 0 && statusCode != http.StatusOK { return true, this.throttleHttpMessage(int(statusCode)), base.NoThrottleReasonHint } + // Replication lag throttle maxLagMillisecondsThrottleThreshold := atomic.LoadInt64(&this.migrationContext.MaxLagMillisecondsThrottleThreshold) lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) @@ -119,7 +122,7 @@ func parseChangelogHeartbeat(heartbeatValue string) (lag time.Duration, err erro // parseChangelogHeartbeat parses a string timestamp and deduces replication lag func (this *Throttler) parseChangelogHeartbeat(heartbeatValue string) (err error) { if lag, err := parseChangelogHeartbeat(heartbeatValue); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } else { atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) return nil @@ -139,15 +142,15 @@ func (this *Throttler) collectReplicationLag(firstThrottlingCollected chan<- boo if this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica { // when running on replica, the heartbeat injection is also done on the replica. // This means we will always get a good heartbeat value. - // When runnign on replica, we should instead check the `SHOW SLAVE STATUS` output. - if lag, err := mysql.GetReplicationLag(this.inspector.connectionConfig); err != nil { - return log.Errore(err) + // When running on replica, we should instead check the `SHOW SLAVE STATUS` output. + if lag, err := mysql.GetReplicationLagFromSlaveStatus(this.inspector.informationSchemaDb); err != nil { + return this.migrationContext.Log.Errore(err) } else { atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) } } else { if heartbeatValue, err := this.inspector.readChangelogState("heartbeat"); err != nil { - return log.Errore(err) + return this.migrationContext.Log.Errore(err) } else { this.parseChangelogHeartbeat(heartbeatValue) } @@ -160,6 +163,9 @@ func (this *Throttler) collectReplicationLag(firstThrottlingCollected chan<- boo ticker := time.Tick(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond) for range ticker { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } go collectFunc() } } @@ -182,11 +188,12 @@ func (this *Throttler) collectControlReplicasLag() { dbUri := connectionConfig.GetDBUri("information_schema") var heartbeatValue string - if db, _, err := sqlutils.GetDB(dbUri); err != nil { + if db, _, err := mysql.GetDB(this.migrationContext.Uuid, dbUri); err != nil { return lag, err } else if err = db.QueryRow(replicationLagQuery).Scan(&heartbeatValue); err != nil { return lag, err } + lag, err = parseChangelogHeartbeat(heartbeatValue) return lag, err } @@ -233,6 +240,9 @@ func (this *Throttler) collectControlReplicasLag() { shouldReadLagAggressively := false for range aggressiveTicker { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } if counter%relaxedFactor == 0 { // we only check if we wish to be aggressive once per second. The parameters for being aggressive // do not typically change at all throughout the migration, but nonetheless we check them. @@ -280,12 +290,31 @@ func (this *Throttler) collectThrottleHTTPStatus(firstThrottlingCollected chan<- return false, nil } - collectFunc() + _, err := collectFunc() + if err != nil { + // If not told to ignore errors, we'll throttle on HTTP connection issues + if !this.migrationContext.IgnoreHTTPErrors { + atomic.StoreInt64(&this.migrationContext.ThrottleHTTPStatusCode, int64(-1)) + } + } + firstThrottlingCollected <- true ticker := time.Tick(100 * time.Millisecond) for range ticker { - if sleep, _ := collectFunc(); sleep { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } + + sleep, err := collectFunc() + if err != nil { + // If not told to ignore errors, we'll throttle on HTTP connection issues + if !this.migrationContext.IgnoreHTTPErrors { + atomic.StoreInt64(&this.migrationContext.ThrottleHTTPStatusCode, int64(-1)) + } + } + + if sleep { time.Sleep(1 * time.Second) } } @@ -318,7 +347,7 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { hibernateDuration := time.Duration(this.migrationContext.CriticalLoadHibernateSeconds) * time.Second hibernateUntilTime := time.Now().Add(hibernateDuration) atomic.StoreInt64(&this.migrationContext.HibernateUntil, hibernateUntilTime.UnixNano()) - log.Errorf("critical-load met: %s=%d, >=%d. Will hibernate for the duration of %+v, until %+v", variableName, value, threshold, hibernateDuration, hibernateUntilTime) + this.migrationContext.Log.Errorf("critical-load met: %s=%d, >=%d. Will hibernate for the duration of %+v, until %+v", variableName, value, threshold, hibernateDuration, hibernateUntilTime) go func() { time.Sleep(hibernateDuration) this.migrationContext.SetThrottleGeneralCheckResult(base.NewThrottleCheckResult(true, "leaving hibernation", base.LeavingHibernationThrottleReasonHint)) @@ -331,7 +360,7 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { this.migrationContext.PanicAbort <- fmt.Errorf("critical-load met: %s=%d, >=%d", variableName, value, threshold) } if criticalLoadMet && this.migrationContext.CriticalLoadIntervalMilliseconds > 0 { - log.Errorf("critical-load met once: %s=%d, >=%d. Will check again in %d millis", variableName, value, threshold, this.migrationContext.CriticalLoadIntervalMilliseconds) + this.migrationContext.Log.Errorf("critical-load met once: %s=%d, >=%d. Will check again in %d millis", variableName, value, threshold, this.migrationContext.CriticalLoadIntervalMilliseconds) go func() { timer := time.NewTimer(time.Millisecond * time.Duration(this.migrationContext.CriticalLoadIntervalMilliseconds)) <-timer.C @@ -393,6 +422,10 @@ func (this *Throttler) initiateThrottlerCollection(firstThrottlingCollected chan throttlerMetricsTick := time.Tick(1 * time.Second) for range throttlerMetricsTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } + this.collectGeneralThrottleMetrics() } }() @@ -419,6 +452,9 @@ func (this *Throttler) initiateThrottlerChecks() error { } throttlerFunction() for range throttlerTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return nil + } throttlerFunction() } @@ -440,3 +476,8 @@ func (this *Throttler) throttle(onThrottled func()) { time.Sleep(250 * time.Millisecond) } } + +func (this *Throttler) Teardown() { + this.migrationContext.Log.Debugf("Tearing down...") + atomic.StoreInt64(&this.finishedMigrating, 1) +} diff --git a/go/mysql/binlog.go b/go/mysql/binlog.go index d98c9e5..50279ce 100644 --- a/go/mysql/binlog.go +++ b/go/mysql/binlog.go @@ -57,7 +57,7 @@ func (this BinlogCoordinates) String() string { return this.DisplayString() } -// Equals tests equality of this corrdinate and another one. +// Equals tests equality of this coordinate and another one. func (this *BinlogCoordinates) Equals(other *BinlogCoordinates) bool { if other == nil { return false @@ -95,8 +95,8 @@ func (this *BinlogCoordinates) FileSmallerThan(other *BinlogCoordinates) bool { return this.LogFile < other.LogFile } -// FileNumberDistance returns the numeric distance between this corrdinate's file number and the other's. -// Effectively it means "how many roatets/FLUSHes would make these coordinates's file reach the other's" +// FileNumberDistance returns the numeric distance between this coordinate's file number and the other's. +// Effectively it means "how many rotates/FLUSHes would make these coordinates's file reach the other's" func (this *BinlogCoordinates) FileNumberDistance(other *BinlogCoordinates) int { thisNumber, _ := this.FileNumber() otherNumber, _ := other.FileNumber() diff --git a/go/mysql/connection.go b/go/mysql/connection.go index 96ae08b..6855ee0 100644 --- a/go/mysql/connection.go +++ b/go/mysql/connection.go @@ -6,8 +6,18 @@ package mysql import ( + "crypto/tls" + "crypto/x509" + "errors" "fmt" + "io/ioutil" "net" + + "github.com/go-sql-driver/mysql" +) + +const ( + TLS_CONFIG_KEY = "ghost" ) // ConnectionConfig is the minimal configuration required to connect to a MySQL server @@ -16,6 +26,8 @@ type ConnectionConfig struct { User string Password string ImpliedKey *InstanceKey + tlsConfig *tls.Config + Timeout float64 } func NewConnectionConfig() *ConnectionConfig { @@ -29,9 +41,11 @@ func NewConnectionConfig() *ConnectionConfig { // DuplicateCredentials creates a new connection config with given key and with same credentials as this config func (this *ConnectionConfig) DuplicateCredentials(key InstanceKey) *ConnectionConfig { config := &ConnectionConfig{ - Key: key, - User: this.User, - Password: this.Password, + Key: key, + User: this.User, + Password: this.Password, + tlsConfig: this.tlsConfig, + Timeout: this.Timeout, } config.ImpliedKey = &config.Key return config @@ -42,13 +56,54 @@ func (this *ConnectionConfig) Duplicate() *ConnectionConfig { } func (this *ConnectionConfig) String() string { - return fmt.Sprintf("%s, user=%s", this.Key.DisplayString(), this.User) + return fmt.Sprintf("%s, user=%s, usingTLS=%t", this.Key.DisplayString(), this.User, this.tlsConfig != nil) } func (this *ConnectionConfig) Equals(other *ConnectionConfig) bool { return this.Key.Equals(&other.Key) || this.ImpliedKey.Equals(other.ImpliedKey) } +func (this *ConnectionConfig) UseTLS(caCertificatePath, clientCertificate, clientKey string, allowInsecure bool) error { + var rootCertPool *x509.CertPool + var certs []tls.Certificate + var err error + + if caCertificatePath == "" { + rootCertPool, err = x509.SystemCertPool() + if err != nil { + return err + } + } else { + rootCertPool = x509.NewCertPool() + pem, err := ioutil.ReadFile(caCertificatePath) + if err != nil { + return err + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + return errors.New("could not add ca certificate to cert pool") + } + } + if clientCertificate != "" || clientKey != "" { + cert, err := tls.LoadX509KeyPair(clientCertificate, clientKey) + if err != nil { + return err + } + certs = []tls.Certificate{cert} + } + + this.tlsConfig = &tls.Config{ + Certificates: certs, + RootCAs: rootCertPool, + InsecureSkipVerify: allowInsecure, + } + + return mysql.RegisterTLSConfig(TLS_CONFIG_KEY, this.tlsConfig) +} + +func (this *ConnectionConfig) TLSConfig() *tls.Config { + return this.tlsConfig +} + func (this *ConnectionConfig) GetDBUri(databaseName string) string { hostname := this.Key.Hostname var ip = net.ParseIP(hostname) @@ -56,5 +111,12 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string { // Wrap IPv6 literals in square brackets hostname = fmt.Sprintf("[%s]", hostname) } - return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?interpolateParams=true&autocommit=true&charset=utf8mb4,utf8,latin1", this.User, this.Password, hostname, this.Key.Port, databaseName) + interpolateParams := true + // go-mysql-driver defaults to false if tls param is not provided; explicitly setting here to + // simplify construction of the DSN below. + tlsOption := "false" + if this.tlsConfig != nil { + tlsOption = TLS_CONFIG_KEY + } + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?timeout=%fs&readTimeout=%fs&writeTimeout=%fs&interpolateParams=%t&autocommit=true&charset=utf8mb4,utf8,latin1&tls=%s", this.User, this.Password, hostname, this.Key.Port, databaseName, this.Timeout, this.Timeout, this.Timeout, interpolateParams, tlsOption) } diff --git a/go/mysql/connection_test.go b/go/mysql/connection_test.go index 657feb0..2befbc9 100644 --- a/go/mysql/connection_test.go +++ b/go/mysql/connection_test.go @@ -6,6 +6,7 @@ package mysql import ( + "crypto/tls" "testing" "github.com/outbrain/golib/log" @@ -31,6 +32,10 @@ func TestDuplicateCredentials(t *testing.T) { c.Key = InstanceKey{Hostname: "myhost", Port: 3306} c.User = "gromit" c.Password = "penguin" + c.tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + ServerName: "feathers", + } dup := c.DuplicateCredentials(InstanceKey{Hostname: "otherhost", Port: 3310}) test.S(t).ExpectEquals(dup.Key.Hostname, "otherhost") @@ -39,6 +44,7 @@ func TestDuplicateCredentials(t *testing.T) { test.S(t).ExpectEquals(dup.ImpliedKey.Port, 3310) test.S(t).ExpectEquals(dup.User, "gromit") test.S(t).ExpectEquals(dup.Password, "penguin") + test.S(t).ExpectEquals(dup.tlsConfig, c.tlsConfig) } func TestDuplicate(t *testing.T) { @@ -55,3 +61,24 @@ func TestDuplicate(t *testing.T) { test.S(t).ExpectEquals(dup.User, "gromit") test.S(t).ExpectEquals(dup.Password, "penguin") } + +func TestGetDBUri(t *testing.T) { + c := NewConnectionConfig() + c.Key = InstanceKey{Hostname: "myhost", Port: 3306} + c.User = "gromit" + c.Password = "penguin" + + uri := c.GetDBUri("test") + test.S(t).ExpectEquals(uri, "gromit:penguin@tcp(myhost:3306)/test?timeout=0.000000s&readTimeout=0.000000s&writeTimeout=0.000000s&interpolateParams=true&autocommit=true&charset=utf8mb4,utf8,latin1&tls=false") +} + +func TestGetDBUriWithTLSSetup(t *testing.T) { + c := NewConnectionConfig() + c.Key = InstanceKey{Hostname: "myhost", Port: 3306} + c.User = "gromit" + c.Password = "penguin" + c.tlsConfig = &tls.Config{} + + uri := c.GetDBUri("test") + test.S(t).ExpectEquals(uri, "gromit:penguin@tcp(myhost:3306)/test?timeout=0.000000s&readTimeout=0.000000s&writeTimeout=0.000000s&interpolateParams=true&autocommit=true&charset=utf8mb4,utf8,latin1&tls=ghost") +} diff --git a/go/mysql/instance_key.go b/go/mysql/instance_key.go index ca5419e..eb108d8 100644 --- a/go/mysql/instance_key.go +++ b/go/mysql/instance_key.go @@ -7,6 +7,7 @@ package mysql import ( "fmt" + "regexp" "strconv" "strings" ) @@ -15,7 +16,14 @@ const ( DefaultInstancePort = 3306 ) -// InstanceKey is an instance indicator, identifued by hostname and port +var ( + ipv4HostPortRegexp = regexp.MustCompile("^([^:]+):([0-9]+)$") + ipv4HostRegexp = regexp.MustCompile("^([^:]+)$") + ipv6HostPortRegexp = regexp.MustCompile("^\\[([:0-9a-fA-F]+)\\]:([0-9]+)$") // e.g. [2001:db8:1f70::999:de8:7648:6e8]:3308 + ipv6HostRegexp = regexp.MustCompile("^([:0-9a-fA-F]+)$") // e.g. 2001:db8:1f70::999:de8:7648:6e8 +) + +// InstanceKey is an instance indicator, identified by hostname and port type InstanceKey struct { Hostname string Port int @@ -25,25 +33,35 @@ const detachHint = "//" // ParseInstanceKey will parse an InstanceKey from a string representation such as 127.0.0.1:3306 func NewRawInstanceKey(hostPort string) (*InstanceKey, error) { - tokens := strings.SplitN(hostPort, ":", 2) - if len(tokens) != 2 { - return nil, fmt.Errorf("Cannot parse InstanceKey from %s. Expected format is host:port", hostPort) + hostname := "" + port := "" + if submatch := ipv4HostPortRegexp.FindStringSubmatch(hostPort); len(submatch) > 0 { + hostname = submatch[1] + port = submatch[2] + } else if submatch := ipv4HostRegexp.FindStringSubmatch(hostPort); len(submatch) > 0 { + hostname = submatch[1] + } else if submatch := ipv6HostPortRegexp.FindStringSubmatch(hostPort); len(submatch) > 0 { + hostname = submatch[1] + port = submatch[2] + } else if submatch := ipv6HostRegexp.FindStringSubmatch(hostPort); len(submatch) > 0 { + hostname = submatch[1] + } else { + return nil, fmt.Errorf("Cannot parse address: %s", hostPort) } - instanceKey := &InstanceKey{Hostname: tokens[0]} - var err error - if instanceKey.Port, err = strconv.Atoi(tokens[1]); err != nil { - return instanceKey, fmt.Errorf("Invalid port: %s", tokens[1]) + instanceKey := &InstanceKey{Hostname: hostname, Port: DefaultInstancePort} + if port != "" { + var err error + if instanceKey.Port, err = strconv.Atoi(port); err != nil { + return instanceKey, fmt.Errorf("Invalid port: %s", port) + } } return instanceKey, nil } -// ParseRawInstanceKeyLoose will parse an InstanceKey from a string representation such as 127.0.0.1:3306. +// ParseInstanceKey will parse an InstanceKey from a string representation such as 127.0.0.1:3306. // The port part is optional; there will be no name resolve -func ParseRawInstanceKeyLoose(hostPort string) (*InstanceKey, error) { - if !strings.Contains(hostPort, ":") { - return &InstanceKey{Hostname: hostPort, Port: DefaultInstancePort}, nil - } +func ParseInstanceKey(hostPort string) (*InstanceKey, error) { return NewRawInstanceKey(hostPort) } @@ -83,7 +101,7 @@ func (this *InstanceKey) IsValid() bool { return len(this.Hostname) > 0 && this.Port > 0 } -// DetachedKey returns an instance key whose hostname is detahced: invalid, but recoverable +// DetachedKey returns an instance key whose hostname is detached: invalid, but recoverable func (this *InstanceKey) DetachedKey() *InstanceKey { if this.IsDetached() { return this @@ -91,7 +109,7 @@ func (this *InstanceKey) DetachedKey() *InstanceKey { return &InstanceKey{Hostname: fmt.Sprintf("%s%s", detachHint, this.Hostname), Port: this.Port} } -// ReattachedKey returns an instance key whose hostname is detahced: invalid, but recoverable +// ReattachedKey returns an instance key whose hostname is detached: invalid, but recoverable func (this *InstanceKey) ReattachedKey() *InstanceKey { if !this.IsDetached() { return this diff --git a/go/mysql/instance_key_map.go b/go/mysql/instance_key_map.go index d0900ef..1065fb9 100644 --- a/go/mysql/instance_key_map.go +++ b/go/mysql/instance_key_map.go @@ -92,7 +92,7 @@ func (this *InstanceKeyMap) ReadCommaDelimitedList(list string) error { } tokens := strings.Split(list, ",") for _, token := range tokens { - key, err := ParseRawInstanceKeyLoose(token) + key, err := ParseInstanceKey(token) if err != nil { return err } diff --git a/go/mysql/instance_key_test.go b/go/mysql/instance_key_test.go new file mode 100644 index 0000000..778a5b3 --- /dev/null +++ b/go/mysql/instance_key_test.go @@ -0,0 +1,74 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package mysql + +import ( + "testing" + + "github.com/outbrain/golib/log" + test "github.com/outbrain/golib/tests" +) + +func init() { + log.SetLevel(log.ERROR) +} + +func TestParseInstanceKey(t *testing.T) { + { + key, err := ParseInstanceKey("myhost:1234") + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(key.Hostname, "myhost") + test.S(t).ExpectEquals(key.Port, 1234) + } + { + key, err := ParseInstanceKey("myhost") + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(key.Hostname, "myhost") + test.S(t).ExpectEquals(key.Port, 3306) + } + { + key, err := ParseInstanceKey("10.0.0.3:3307") + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(key.Hostname, "10.0.0.3") + test.S(t).ExpectEquals(key.Port, 3307) + } + { + key, err := ParseInstanceKey("10.0.0.3") + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(key.Hostname, "10.0.0.3") + test.S(t).ExpectEquals(key.Port, 3306) + } + { + key, err := ParseInstanceKey("[2001:db8:1f70::999:de8:7648:6e8]:3308") + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(key.Hostname, "2001:db8:1f70::999:de8:7648:6e8") + test.S(t).ExpectEquals(key.Port, 3308) + } + { + key, err := ParseInstanceKey("::1") + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(key.Hostname, "::1") + test.S(t).ExpectEquals(key.Port, 3306) + } + { + key, err := ParseInstanceKey("0:0:0:0:0:0:0:0") + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(key.Hostname, "0:0:0:0:0:0:0:0") + test.S(t).ExpectEquals(key.Port, 3306) + } + { + _, err := ParseInstanceKey("[2001:xxxx:1f70::999:de8:7648:6e8]:3308") + test.S(t).ExpectNotNil(err) + } + { + _, err := ParseInstanceKey("10.0.0.4:") + test.S(t).ExpectNotNil(err) + } + { + _, err := ParseInstanceKey("10.0.0.4:5.6.7") + test.S(t).ExpectNotNil(err) + } +} diff --git a/go/mysql/utils.go b/go/mysql/utils.go index b670921..17bb5fc 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -8,6 +8,8 @@ package mysql import ( gosql "database/sql" "fmt" + "strings" + "sync" "time" "github.com/github/gh-ost/go/sql" @@ -33,16 +35,32 @@ func (this *ReplicationLagResult) HasLag() bool { return this.Lag > 0 } -// GetReplicationLag returns replication lag for a given connection config; either by explicit query -// or via SHOW SLAVE STATUS -func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time.Duration, err error) { - dbUri := connectionConfig.GetDBUri("information_schema") - var db *gosql.DB - if db, _, err = sqlutils.GetDB(dbUri); err != nil { - return replicationLag, err - } +// knownDBs is a DB cache by uri +var knownDBs map[string]*gosql.DB = make(map[string]*gosql.DB) +var knownDBsMutex = &sync.Mutex{} - err = sqlutils.QueryRowsMap(db, `show slave status`, func(m sqlutils.RowMap) error { +func GetDB(migrationUuid string, mysql_uri string) (*gosql.DB, bool, error) { + cacheKey := migrationUuid + ":" + mysql_uri + + knownDBsMutex.Lock() + defer func() { + knownDBsMutex.Unlock() + }() + + var exists bool + if _, exists = knownDBs[cacheKey]; !exists { + if db, err := gosql.Open("mysql", mysql_uri); err == nil { + knownDBs[cacheKey] = db + } else { + return db, exists, err + } + } + return knownDBs[cacheKey], exists, nil +} + +// GetReplicationLagFromSlaveStatus returns replication lag for a given db; via SHOW SLAVE STATUS +func GetReplicationLagFromSlaveStatus(informationSchemaDb *gosql.DB) (replicationLag time.Duration, err error) { + err = sqlutils.QueryRowsMap(informationSchemaDb, `show slave status`, func(m sqlutils.RowMap) error { slaveIORunning := m.GetString("Slave_IO_Running") slaveSQLRunning := m.GetString("Slave_SQL_Running") secondsBehindMaster := m.GetNullInt64("Seconds_Behind_Master") @@ -52,15 +70,19 @@ func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time. replicationLag = time.Duration(secondsBehindMaster.Int64) * time.Second return nil }) + return replicationLag, err } func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey *InstanceKey, err error) { currentUri := connectionConfig.GetDBUri("information_schema") - db, _, err := sqlutils.GetDB(currentUri) + // This function is only called once, okay to not have a cached connection pool + db, err := gosql.Open("mysql", currentUri) if err != nil { return nil, err } + defer db.Close() + err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { // We wish to recognize the case where the topology's master actually has replication configuration. // This can happen when a DBA issues a `RESET SLAVE` instead of `RESET SLAVE ALL`. @@ -73,7 +95,6 @@ func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey slaveIORunning := rowMap.GetString("Slave_IO_Running") slaveSQLRunning := rowMap.GetString("Slave_SQL_Running") - // if slaveIORunning != "Yes" || slaveSQLRunning != "Yes" { return fmt.Errorf("Replication on %+v is broken: Slave_IO_Running: %s, Slave_SQL_Running: %s. Please make sure replication runs before using gh-ost.", connectionConfig.Key, @@ -153,7 +174,7 @@ func GetInstanceKey(db *gosql.DB) (instanceKey *InstanceKey, err error) { } // GetTableColumns reads column list from given table -func GetTableColumns(db *gosql.DB, databaseName, tableName string) (*sql.ColumnList, error) { +func GetTableColumns(db *gosql.DB, databaseName, tableName string) (*sql.ColumnList, *sql.ColumnList, error) { query := fmt.Sprintf(` show columns from %s.%s `, @@ -161,18 +182,24 @@ func GetTableColumns(db *gosql.DB, databaseName, tableName string) (*sql.ColumnL sql.EscapeName(tableName), ) columnNames := []string{} + virtualColumnNames := []string{} err := sqlutils.QueryRowsMap(db, query, func(rowMap sqlutils.RowMap) error { - columnNames = append(columnNames, rowMap.GetString("Field")) + columnName := rowMap.GetString("Field") + columnNames = append(columnNames, columnName) + if strings.Contains(rowMap.GetString("Extra"), " GENERATED") { + log.Debugf("%s is a generated column", columnName) + virtualColumnNames = append(virtualColumnNames, columnName) + } return nil }) if err != nil { - return nil, err + return nil, nil, err } if len(columnNames) == 0 { - return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out", + return nil, nil, log.Errorf("Found 0 columns on %s.%s. Bailing out", sql.EscapeName(databaseName), sql.EscapeName(tableName), ) } - return sql.NewColumnList(columnNames), nil + return sql.NewColumnList(columnNames), sql.NewColumnList(virtualColumnNames), nil } diff --git a/go/sql/builder.go b/go/sql/builder.go index 251a874..4b019bc 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -15,11 +15,11 @@ type ValueComparisonSign string const ( LessThanComparisonSign ValueComparisonSign = "<" - LessThanOrEqualsComparisonSign = "<=" - EqualsComparisonSign = "=" - GreaterThanOrEqualsComparisonSign = ">=" - GreaterThanComparisonSign = ">" - NotEqualsComparisonSign = "!=" + LessThanOrEqualsComparisonSign ValueComparisonSign = "<=" + EqualsComparisonSign ValueComparisonSign = "=" + GreaterThanOrEqualsComparisonSign ValueComparisonSign = ">=" + GreaterThanComparisonSign ValueComparisonSign = ">" + NotEqualsComparisonSign ValueComparisonSign = "!=" ) // EscapeName will escape a db/table/column/... name by wrapping with backticks. @@ -140,13 +140,12 @@ func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisons := []string{} for i, column := range columns { - // value := values[i] rangeComparison, err := BuildValueComparison(column, value, comparisonSign) if err != nil { return "", explodedArgs, err } - if len(columns[0:i]) > 0 { + if i > 0 { equalitiesComparison, err := BuildEqualsComparison(columns[0:i], values[0:i]) if err != nil { return "", explodedArgs, err @@ -493,6 +492,9 @@ func BuildDMLUpdateQuery(databaseName, tableName string, tableColumns, sharedCol } setClause, err := BuildSetPreparedClause(mappedSharedColumns) + if err != nil { + return "", sharedArgs, uniqueKeyArgs, err + } equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names()) result = fmt.Sprintf(` diff --git a/go/sql/encoding.go b/go/sql/encoding.go index ac38f85..767bd9d 100644 --- a/go/sql/encoding.go +++ b/go/sql/encoding.go @@ -8,6 +8,7 @@ package sql import ( "golang.org/x/text/encoding" "golang.org/x/text/encoding/charmap" + "golang.org/x/text/encoding/simplifiedchinese" ) type charsetEncoding map[string]encoding.Encoding @@ -18,4 +19,5 @@ func init() { charsetEncodingMap = make(map[string]encoding.Encoding) // Begin mappings charsetEncodingMap["latin1"] = charmap.Windows1252 + charsetEncodingMap["gbk"] = simplifiedchinese.GBK } diff --git a/go/sql/parser.go b/go/sql/parser.go index 7114e10..ebb8b38 100644 --- a/go/sql/parser.go +++ b/go/sql/parser.go @@ -12,24 +12,54 @@ import ( ) var ( - sanitizeQuotesRegexp = regexp.MustCompile("('[^']*')") - renameColumnRegexp = regexp.MustCompile(`(?i)\bchange\s+(column\s+|)([\S]+)\s+([\S]+)\s+`) - dropColumnRegexp = regexp.MustCompile(`(?i)\bdrop\s+(column\s+|)([\S]+)$`) + sanitizeQuotesRegexp = regexp.MustCompile("('[^']*')") + renameColumnRegexp = regexp.MustCompile(`(?i)\bchange\s+(column\s+|)([\S]+)\s+([\S]+)\s+`) + dropColumnRegexp = regexp.MustCompile(`(?i)\bdrop\s+(column\s+|)([\S]+)$`) + renameTableRegexp = regexp.MustCompile(`(?i)\brename\s+(to|as)\s+`) + alterTableExplicitSchemaTableRegexps = []*regexp.Regexp{ + // ALTER TABLE `scm`.`tbl` something + regexp.MustCompile(`(?i)\balter\s+table\s+` + "`" + `([^` + "`" + `]+)` + "`" + `[.]` + "`" + `([^` + "`" + `]+)` + "`" + `\s+(.*$)`), + // ALTER TABLE `scm`.tbl something + regexp.MustCompile(`(?i)\balter\s+table\s+` + "`" + `([^` + "`" + `]+)` + "`" + `[.]([\S]+)\s+(.*$)`), + // ALTER TABLE scm.`tbl` something + regexp.MustCompile(`(?i)\balter\s+table\s+([\S]+)[.]` + "`" + `([^` + "`" + `]+)` + "`" + `\s+(.*$)`), + // ALTER TABLE scm.tbl something + regexp.MustCompile(`(?i)\balter\s+table\s+([\S]+)[.]([\S]+)\s+(.*$)`), + } + alterTableExplicitTableRegexps = []*regexp.Regexp{ + // ALTER TABLE `tbl` something + regexp.MustCompile(`(?i)\balter\s+table\s+` + "`" + `([^` + "`" + `]+)` + "`" + `\s+(.*$)`), + // ALTER TABLE tbl something + regexp.MustCompile(`(?i)\balter\s+table\s+([\S]+)\s+(.*$)`), + } ) -type Parser struct { +type AlterTableParser struct { columnRenameMap map[string]string droppedColumns map[string]bool + isRenameTable bool + + alterStatementOptions string + alterTokens []string + + explicitSchema string + explicitTable string } -func NewParser() *Parser { - return &Parser{ +func NewAlterTableParser() *AlterTableParser { + return &AlterTableParser{ columnRenameMap: make(map[string]string), droppedColumns: make(map[string]bool), } } -func (this *Parser) tokenizeAlterStatement(alterStatement string) (tokens []string, err error) { +func NewParserFromAlterStatement(alterStatement string) *AlterTableParser { + parser := NewAlterTableParser() + parser.ParseAlterStatement(alterStatement) + return parser +} + +func (this *AlterTableParser) tokenizeAlterStatement(alterStatement string) (tokens []string, err error) { terminatingQuote := rune(0) f := func(c rune) bool { switch { @@ -56,13 +86,13 @@ func (this *Parser) tokenizeAlterStatement(alterStatement string) (tokens []stri return tokens, nil } -func (this *Parser) sanitizeQuotesFromAlterStatement(alterStatement string) (strippedStatement string) { +func (this *AlterTableParser) sanitizeQuotesFromAlterStatement(alterStatement string) (strippedStatement string) { strippedStatement = alterStatement strippedStatement = sanitizeQuotesRegexp.ReplaceAllString(strippedStatement, "''") return strippedStatement } -func (this *Parser) parseAlterToken(alterToken string) (err error) { +func (this *AlterTableParser) parseAlterToken(alterToken string) (err error) { { // rename allStringSubmatch := renameColumnRegexp.FindAllStringSubmatch(alterToken, -1) @@ -86,19 +116,43 @@ func (this *Parser) parseAlterToken(alterToken string) (err error) { this.droppedColumns[submatch[2]] = true } } - return nil -} - -func (this *Parser) ParseAlterStatement(alterStatement string) (err error) { - alterTokens, _ := this.tokenizeAlterStatement(alterStatement) - for _, alterToken := range alterTokens { - alterToken = this.sanitizeQuotesFromAlterStatement(alterToken) - this.parseAlterToken(alterToken) + { + // rename table + if renameTableRegexp.MatchString(alterToken) { + this.isRenameTable = true + } } return nil } -func (this *Parser) GetNonTrivialRenames() map[string]string { +func (this *AlterTableParser) ParseAlterStatement(alterStatement string) (err error) { + + this.alterStatementOptions = alterStatement + for _, alterTableRegexp := range alterTableExplicitSchemaTableRegexps { + if submatch := alterTableRegexp.FindStringSubmatch(this.alterStatementOptions); len(submatch) > 0 { + this.explicitSchema = submatch[1] + this.explicitTable = submatch[2] + this.alterStatementOptions = submatch[3] + break + } + } + for _, alterTableRegexp := range alterTableExplicitTableRegexps { + if submatch := alterTableRegexp.FindStringSubmatch(this.alterStatementOptions); len(submatch) > 0 { + this.explicitTable = submatch[1] + this.alterStatementOptions = submatch[2] + break + } + } + alterTokens, _ := this.tokenizeAlterStatement(this.alterStatementOptions) + for _, alterToken := range alterTokens { + alterToken = this.sanitizeQuotesFromAlterStatement(alterToken) + this.parseAlterToken(alterToken) + this.alterTokens = append(this.alterTokens, alterToken) + } + return nil +} + +func (this *AlterTableParser) GetNonTrivialRenames() map[string]string { result := make(map[string]string) for column, renamed := range this.columnRenameMap { if column != renamed { @@ -108,10 +162,33 @@ func (this *Parser) GetNonTrivialRenames() map[string]string { return result } -func (this *Parser) HasNonTrivialRenames() bool { +func (this *AlterTableParser) HasNonTrivialRenames() bool { return len(this.GetNonTrivialRenames()) > 0 } -func (this *Parser) DroppedColumnsMap() map[string]bool { +func (this *AlterTableParser) DroppedColumnsMap() map[string]bool { return this.droppedColumns } + +func (this *AlterTableParser) IsRenameTable() bool { + return this.isRenameTable +} +func (this *AlterTableParser) GetExplicitSchema() string { + return this.explicitSchema +} + +func (this *AlterTableParser) HasExplicitSchema() bool { + return this.GetExplicitSchema() != "" +} + +func (this *AlterTableParser) GetExplicitTable() string { + return this.explicitTable +} + +func (this *AlterTableParser) HasExplicitTable() bool { + return this.GetExplicitTable() != "" +} + +func (this *AlterTableParser) GetAlterStatementOptions() string { + return this.alterStatementOptions +} diff --git a/go/sql/parser_test.go b/go/sql/parser_test.go index 3e1d845..79faa63 100644 --- a/go/sql/parser_test.go +++ b/go/sql/parser_test.go @@ -19,17 +19,19 @@ func init() { func TestParseAlterStatement(t *testing.T) { statement := "add column t int, engine=innodb" - parser := NewParser() + parser := NewAlterTableParser() err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.alterStatementOptions, statement) test.S(t).ExpectFalse(parser.HasNonTrivialRenames()) } func TestParseAlterStatementTrivialRename(t *testing.T) { statement := "add column t int, change ts ts timestamp, engine=innodb" - parser := NewParser() + parser := NewAlterTableParser() err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.alterStatementOptions, statement) test.S(t).ExpectFalse(parser.HasNonTrivialRenames()) test.S(t).ExpectEquals(len(parser.columnRenameMap), 1) test.S(t).ExpectEquals(parser.columnRenameMap["ts"], "ts") @@ -37,9 +39,10 @@ func TestParseAlterStatementTrivialRename(t *testing.T) { func TestParseAlterStatementTrivialRenames(t *testing.T) { statement := "add column t int, change ts ts timestamp, CHANGE f `f` float, engine=innodb" - parser := NewParser() + parser := NewAlterTableParser() err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.alterStatementOptions, statement) test.S(t).ExpectFalse(parser.HasNonTrivialRenames()) test.S(t).ExpectEquals(len(parser.columnRenameMap), 2) test.S(t).ExpectEquals(parser.columnRenameMap["ts"], "ts") @@ -58,9 +61,10 @@ func TestParseAlterStatementNonTrivial(t *testing.T) { } for _, statement := range statements { - parser := NewParser() + parser := NewAlterTableParser() err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.alterStatementOptions, statement) renames := parser.GetNonTrivialRenames() test.S(t).ExpectEquals(len(renames), 2) test.S(t).ExpectEquals(renames["i"], "count") @@ -69,7 +73,7 @@ func TestParseAlterStatementNonTrivial(t *testing.T) { } func TestTokenizeAlterStatement(t *testing.T) { - parser := NewParser() + parser := NewAlterTableParser() { alterStatement := "add column t int" tokens, _ := parser.tokenizeAlterStatement(alterStatement) @@ -108,7 +112,7 @@ func TestTokenizeAlterStatement(t *testing.T) { } func TestSanitizeQuotesFromAlterStatement(t *testing.T) { - parser := NewParser() + parser := NewAlterTableParser() { alterStatement := "add column e enum('a','b','c')" strippedStatement := parser.sanitizeQuotesFromAlterStatement(alterStatement) @@ -124,7 +128,7 @@ func TestSanitizeQuotesFromAlterStatement(t *testing.T) { func TestParseAlterStatementDroppedColumns(t *testing.T) { { - parser := NewParser() + parser := NewAlterTableParser() statement := "drop column b" err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) @@ -132,16 +136,17 @@ func TestParseAlterStatementDroppedColumns(t *testing.T) { test.S(t).ExpectTrue(parser.droppedColumns["b"]) } { - parser := NewParser() + parser := NewAlterTableParser() statement := "drop column b, drop key c_idx, drop column `d`" err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.alterStatementOptions, statement) test.S(t).ExpectEquals(len(parser.droppedColumns), 2) test.S(t).ExpectTrue(parser.droppedColumns["b"]) test.S(t).ExpectTrue(parser.droppedColumns["d"]) } { - parser := NewParser() + parser := NewAlterTableParser() statement := "drop column b, drop key c_idx, drop column `d`, drop `e`, drop primary key, drop foreign key fk_1" err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) @@ -151,7 +156,7 @@ func TestParseAlterStatementDroppedColumns(t *testing.T) { test.S(t).ExpectTrue(parser.droppedColumns["e"]) } { - parser := NewParser() + parser := NewAlterTableParser() statement := "drop column b, drop bad statement, add column i int" err := parser.ParseAlterStatement(statement) test.S(t).ExpectNil(err) @@ -159,3 +164,137 @@ func TestParseAlterStatementDroppedColumns(t *testing.T) { test.S(t).ExpectTrue(parser.droppedColumns["b"]) } } + +func TestParseAlterStatementRenameTable(t *testing.T) { + + { + parser := NewAlterTableParser() + statement := "drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectFalse(parser.isRenameTable) + } + { + parser := NewAlterTableParser() + statement := "rename as something_else" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectTrue(parser.isRenameTable) + } + { + parser := NewAlterTableParser() + statement := "drop column b, rename as something_else" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.alterStatementOptions, statement) + test.S(t).ExpectTrue(parser.isRenameTable) + } + { + parser := NewAlterTableParser() + statement := "engine=innodb rename as something_else" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectTrue(parser.isRenameTable) + } + { + parser := NewAlterTableParser() + statement := "rename as something_else, engine=innodb" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectTrue(parser.isRenameTable) + } +} + +func TestParseAlterStatementExplicitTable(t *testing.T) { + + { + parser := NewAlterTableParser() + statement := "drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "") + test.S(t).ExpectEquals(parser.explicitTable, "") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table tbl drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "") + test.S(t).ExpectEquals(parser.explicitTable, "tbl") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table `tbl` drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "") + test.S(t).ExpectEquals(parser.explicitTable, "tbl") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table `scm with spaces`.`tbl` drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "scm with spaces") + test.S(t).ExpectEquals(parser.explicitTable, "tbl") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table `scm`.`tbl with spaces` drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "scm") + test.S(t).ExpectEquals(parser.explicitTable, "tbl with spaces") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table `scm`.tbl drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "scm") + test.S(t).ExpectEquals(parser.explicitTable, "tbl") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table scm.`tbl` drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "scm") + test.S(t).ExpectEquals(parser.explicitTable, "tbl") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table scm.tbl drop column b" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "scm") + test.S(t).ExpectEquals(parser.explicitTable, "tbl") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + } + { + parser := NewAlterTableParser() + statement := "alter table scm.tbl drop column b, add index idx(i)" + err := parser.ParseAlterStatement(statement) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(parser.explicitSchema, "scm") + test.S(t).ExpectEquals(parser.explicitTable, "tbl") + test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b, add index idx(i)") + test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b", "add index idx(i)"})) + } +} diff --git a/go/sql/types.go b/go/sql/types.go index 15a99ff..ef83819 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -15,18 +15,18 @@ import ( type ColumnType int const ( - UnknownColumnType ColumnType = iota - TimestampColumnType = iota - DateTimeColumnType = iota - EnumColumnType = iota - MediumIntColumnType = iota - JSONColumnType = iota - FloatColumnType = iota + UnknownColumnType ColumnType = iota + TimestampColumnType + DateTimeColumnType + EnumColumnType + MediumIntColumnType + JSONColumnType + FloatColumnType ) const maxMediumintUnsigned int32 = 16777215 -type TimezoneConvertion struct { +type TimezoneConversion struct { ToTimezone string } @@ -35,7 +35,7 @@ type Column struct { IsUnsigned bool Charset string Type ColumnType - timezoneConversion *TimezoneConvertion + timezoneConversion *TimezoneConversion } func (this *Column) convertArg(arg interface{}) interface{} { @@ -172,7 +172,7 @@ func (this *ColumnList) GetColumnType(columnName string) ColumnType { } func (this *ColumnList) SetConvertDatetimeToTimestamp(columnName string, toTimezone string) { - this.GetColumn(columnName).timezoneConversion = &TimezoneConvertion{ToTimezone: toTimezone} + this.GetColumn(columnName).timezoneConversion = &TimezoneConversion{ToTimezone: toTimezone} } func (this *ColumnList) HasTimezoneConversion(columnName string) bool { diff --git a/localtests/autoinc-zero-value/create.sql b/localtests/autoinc-zero-value/create.sql new file mode 100644 index 0000000..ba08bd4 --- /dev/null +++ b/localtests/autoinc-zero-value/create.sql @@ -0,0 +1,9 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + primary key(id) +) auto_increment=1; + +set session sql_mode='NO_AUTO_VALUE_ON_ZERO'; +insert into gh_ost_test values (0, 23); diff --git a/localtests/bigint-change-nullable/create.sql b/localtests/bigint-change-nullable/create.sql new file mode 100644 index 0000000..f4f0548 --- /dev/null +++ b/localtests/bigint-change-nullable/create.sql @@ -0,0 +1,21 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id bigint auto_increment, + val bigint not null, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 18446744073709551615); + insert into gh_ost_test values (null, 18446744073709551614); + insert into gh_ost_test values (null, 18446744073709551613); +end ;; diff --git a/localtests/bigint-change-nullable/extra_args b/localtests/bigint-change-nullable/extra_args new file mode 100644 index 0000000..784d522 --- /dev/null +++ b/localtests/bigint-change-nullable/extra_args @@ -0,0 +1 @@ +--alter="change val val bigint" diff --git a/localtests/bit-add/create.sql b/localtests/bit-add/create.sql new file mode 100644 index 0000000..d58934a --- /dev/null +++ b/localtests/bit-add/create.sql @@ -0,0 +1,20 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11); + insert into gh_ost_test values (null, 13); +end ;; diff --git a/localtests/bit-add/extra_args b/localtests/bit-add/extra_args new file mode 100644 index 0000000..9d26250 --- /dev/null +++ b/localtests/bit-add/extra_args @@ -0,0 +1 @@ +--alter="add column is_good bit null default 0" diff --git a/localtests/bit-add/ghost_columns b/localtests/bit-add/ghost_columns new file mode 100644 index 0000000..b464f06 --- /dev/null +++ b/localtests/bit-add/ghost_columns @@ -0,0 +1 @@ +id, i diff --git a/localtests/bit-add/orig_columns b/localtests/bit-add/orig_columns new file mode 100644 index 0000000..b464f06 --- /dev/null +++ b/localtests/bit-add/orig_columns @@ -0,0 +1 @@ +id, i diff --git a/localtests/bit-dml/create.sql b/localtests/bit-dml/create.sql new file mode 100644 index 0000000..c7c26af --- /dev/null +++ b/localtests/bit-dml/create.sql @@ -0,0 +1,24 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + is_good bit null default 0, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11, 0); + insert into gh_ost_test values (null, 13, 1); + insert into gh_ost_test values (null, 17, 1); + + update gh_ost_test set is_good=0 where i=13 order by id desc limit 1; +end ;; diff --git a/localtests/bit-dml/extra_args b/localtests/bit-dml/extra_args new file mode 100644 index 0000000..a3abab2 --- /dev/null +++ b/localtests/bit-dml/extra_args @@ -0,0 +1 @@ +--alter="modify column is_good bit not null default 0" --approve-renamed-columns diff --git a/localtests/convert-utf8mb4/create.sql b/localtests/convert-utf8mb4/create.sql new file mode 100644 index 0000000..05f1a13 --- /dev/null +++ b/localtests/convert-utf8mb4/create.sql @@ -0,0 +1,31 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + t varchar(128) charset utf8 collate utf8_general_ci, + tl varchar(128) charset latin1 not null, + ta varchar(128) charset ascii not null, + primary key(id) +) auto_increment=1; + +insert into gh_ost_test values (null, 'átesting'); + + +insert into gh_ost_test values (null, 'Hello world, Καλημέρα κόσμε, コンニチハ', 'átesting0', 'initial'); + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, md5(rand()), 'átesting-a', 'a'); + insert into gh_ost_test values (null, 'novo proprietário', 'átesting-b', 'b'); + insert into gh_ost_test values (null, '2H₂ + O₂ ⇌ 2H₂O, R = 4.7 kΩ, ⌀ 200 mm', 'átesting-c', 'c'); + insert into gh_ost_test values (null, 'usuário', 'átesting-x', 'x'); + + delete from gh_ost_test where ta='x' order by id desc limit 1; +end ;; diff --git a/localtests/convert-utf8mb4/extra_args b/localtests/convert-utf8mb4/extra_args new file mode 100644 index 0000000..2b2f64d --- /dev/null +++ b/localtests/convert-utf8mb4/extra_args @@ -0,0 +1 @@ +--alter='convert to character set utf8mb4' diff --git a/localtests/datetime-1970/create.sql b/localtests/datetime-1970/create.sql new file mode 100644 index 0000000..a1914c6 --- /dev/null +++ b/localtests/datetime-1970/create.sql @@ -0,0 +1,27 @@ +set session time_zone='+00:00'; + +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + create_time timestamp NULL DEFAULT '0000-00-00 00:00:00', + update_time timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + counter int(10) unsigned DEFAULT NULL, + primary key(id) +) auto_increment=1; + +set session time_zone='+00:00'; +insert into gh_ost_test values (1, '0000-00-00 00:00:00', now(), 0); + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + set session time_zone='+00:00'; + update gh_ost_test set counter = counter + 1 where id = 1; +end ;; diff --git a/localtests/datetime-1970/extra_args b/localtests/datetime-1970/extra_args new file mode 100644 index 0000000..453f761 --- /dev/null +++ b/localtests/datetime-1970/extra_args @@ -0,0 +1 @@ +--alter='add column name varchar(1)' diff --git a/localtests/datetime-1970/ghost_columns b/localtests/datetime-1970/ghost_columns new file mode 100644 index 0000000..581038f --- /dev/null +++ b/localtests/datetime-1970/ghost_columns @@ -0,0 +1 @@ +id, create_time, update_time, counter diff --git a/localtests/datetime-1970/orig_columns b/localtests/datetime-1970/orig_columns new file mode 100644 index 0000000..581038f --- /dev/null +++ b/localtests/datetime-1970/orig_columns @@ -0,0 +1 @@ +id, create_time, update_time, counter diff --git a/localtests/datetime-1970/sql_mode b/localtests/datetime-1970/sql_mode new file mode 100644 index 0000000..e69de29 diff --git a/localtests/datetime-submillis-zeroleading/ignore_versions b/localtests/datetime-submillis-zeroleading/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/datetime-submillis-zeroleading/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/datetime-submillis/create.sql b/localtests/datetime-submillis/create.sql index b4e0b0b..6c04adb 100644 --- a/localtests/datetime-submillis/create.sql +++ b/localtests/datetime-submillis/create.sql @@ -17,7 +17,7 @@ create event gh_ost_test starts current_timestamp ends current_timestamp + interval 60 second on completion not preserve - enable + disable on slave do begin insert into gh_ost_test values (null, 11, now(), now(), now(), 0); diff --git a/localtests/datetime-submillis/ignore_versions b/localtests/datetime-submillis/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/datetime-submillis/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/datetime-to-timestamp-pk-fail/ignore_versions b/localtests/datetime-to-timestamp-pk-fail/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/datetime-to-timestamp-pk-fail/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/datetime/ignore_versions b/localtests/datetime/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/datetime/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/decimal/create.sql b/localtests/decimal/create.sql new file mode 100644 index 0000000..248c86a --- /dev/null +++ b/localtests/decimal/create.sql @@ -0,0 +1,23 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + dec0 decimal(65,30) unsigned NOT NULL DEFAULT '0.000000000000000000000000000000', + dec1 decimal(65,30) unsigned NOT NULL DEFAULT '1.000000000000000000000000000000', + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 0.0, 0.0); + insert into gh_ost_test values (null, 2.0, 4.0); + insert into gh_ost_test values (null, 99999999999999999999999999999999999.000, 6.0); + update gh_ost_test set dec1=4.5 where dec2=4.0 order by id desc limit 1; +end ;; diff --git a/localtests/fail-rename-table/create.sql b/localtests/fail-rename-table/create.sql new file mode 100644 index 0000000..5bb45f2 --- /dev/null +++ b/localtests/fail-rename-table/create.sql @@ -0,0 +1,22 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + ts timestamp, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11, now()); + insert into gh_ost_test values (null, 13, now()); + insert into gh_ost_test values (null, 17, now()); +end ;; diff --git a/localtests/fail-rename-table/expect_failure b/localtests/fail-rename-table/expect_failure new file mode 100644 index 0000000..e444c17 --- /dev/null +++ b/localtests/fail-rename-table/expect_failure @@ -0,0 +1 @@ +ALTER statement seems to RENAME the table diff --git a/localtests/fail-rename-table/extra_args b/localtests/fail-rename-table/extra_args new file mode 100644 index 0000000..28a7587 --- /dev/null +++ b/localtests/fail-rename-table/extra_args @@ -0,0 +1 @@ +--alter="rename as something_else" diff --git a/localtests/fail-update-pk-column/create.sql b/localtests/fail-update-pk-column/create.sql new file mode 100644 index 0000000..5cc1d37 --- /dev/null +++ b/localtests/fail-update-pk-column/create.sql @@ -0,0 +1,52 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + primary key(id) +) auto_increment=1; + +insert into gh_ost_test values (null, 101); +insert into gh_ost_test values (null, 102); +insert into gh_ost_test values (null, 103); +insert into gh_ost_test values (null, 104); +insert into gh_ost_test values (null, 105); +insert into gh_ost_test values (null, 106); +insert into gh_ost_test values (null, 107); +insert into gh_ost_test values (null, 108); +insert into gh_ost_test values (null, 109); +insert into gh_ost_test values (null, 110); +insert into gh_ost_test values (null, 111); +insert into gh_ost_test values (null, 112); +insert into gh_ost_test values (null, 113); +insert into gh_ost_test values (null, 114); +insert into gh_ost_test values (null, 115); +insert into gh_ost_test values (null, 116); +insert into gh_ost_test values (null, 117); +insert into gh_ost_test values (null, 118); +insert into gh_ost_test values (null, 119); +insert into gh_ost_test values (null, 120); +insert into gh_ost_test values (null, 121); +insert into gh_ost_test values (null, 122); +insert into gh_ost_test values (null, 123); +insert into gh_ost_test values (null, 124); +insert into gh_ost_test values (null, 125); +insert into gh_ost_test values (null, 126); +insert into gh_ost_test values (null, 127); +insert into gh_ost_test values (null, 128); +insert into gh_ost_test values (null, 129); + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + interval 3 second + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + update gh_ost_test set id=-2 where id=21; + update gh_ost_test set id=55 where id=22; + update gh_ost_test set id=23 where id=23; + update gh_ost_test set i=5024 where id=24; +end ;; diff --git a/localtests/gbk-charset/create.sql b/localtests/gbk-charset/create.sql new file mode 100644 index 0000000..08e1fae --- /dev/null +++ b/localtests/gbk-charset/create.sql @@ -0,0 +1,25 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int(11) NOT NULL AUTO_INCREMENT, + name varchar(512) DEFAULT NULL, + v varchar(255) DEFAULT NULL COMMENT '添加普通列测试', + PRIMARY KEY (id) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=gbk; + +insert into gh_ost_test values (null, 'gbk-test-initial', '添加普通列测试-添加普通列测试'); +insert into gh_ost_test values (null, 'gbk-test-initial', '添加普通列测试-添加普通列测试'); + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test (name) values ('gbk-test-default'); + insert into gh_ost_test values (null, 'gbk-test', '添加普通列测试-添加普通列测试'); + update gh_ost_test set v='添加普通列测试' where v='添加普通列测试-添加普通列测试' order by id desc limit 1; +end ;; diff --git a/localtests/gbk-charset/extra_args b/localtests/gbk-charset/extra_args new file mode 100644 index 0000000..e69de29 diff --git a/localtests/generated-columns-add57/create.sql b/localtests/generated-columns-add57/create.sql new file mode 100644 index 0000000..ef795ba --- /dev/null +++ b/localtests/generated-columns-add57/create.sql @@ -0,0 +1,29 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + a int not null, + b int not null, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test (id, a, b) values (null, 2,3); + insert into gh_ost_test (id, a, b) values (null, 2,4); + insert into gh_ost_test (id, a, b) values (null, 2,5); + insert into gh_ost_test (id, a, b) values (null, 2,6); + insert into gh_ost_test (id, a, b) values (null, 2,7); + insert into gh_ost_test (id, a, b) values (null, 2,8); + insert into gh_ost_test (id, a, b) values (null, 2,9); + insert into gh_ost_test (id, a, b) values (null, 2,0); + insert into gh_ost_test (id, a, b) values (null, 2,1); + insert into gh_ost_test (id, a, b) values (null, 2,2); +end ;; diff --git a/localtests/generated-columns-add57/extra_args b/localtests/generated-columns-add57/extra_args new file mode 100644 index 0000000..b2bf5bc --- /dev/null +++ b/localtests/generated-columns-add57/extra_args @@ -0,0 +1 @@ +--alter="add column sum_ab int as (a + b) virtual not null" diff --git a/localtests/generated-columns-add57/ghost_columns b/localtests/generated-columns-add57/ghost_columns new file mode 100644 index 0000000..bd17155 --- /dev/null +++ b/localtests/generated-columns-add57/ghost_columns @@ -0,0 +1 @@ +id, a, b diff --git a/localtests/generated-columns-add57/ignore_versions b/localtests/generated-columns-add57/ignore_versions new file mode 100644 index 0000000..b6de5f8 --- /dev/null +++ b/localtests/generated-columns-add57/ignore_versions @@ -0,0 +1 @@ +(5.5|5.6) diff --git a/localtests/generated-columns-add57/order_by b/localtests/generated-columns-add57/order_by new file mode 100644 index 0000000..074d1ee --- /dev/null +++ b/localtests/generated-columns-add57/order_by @@ -0,0 +1 @@ +id diff --git a/localtests/generated-columns-add57/orig_columns b/localtests/generated-columns-add57/orig_columns new file mode 100644 index 0000000..bd17155 --- /dev/null +++ b/localtests/generated-columns-add57/orig_columns @@ -0,0 +1 @@ +id, a, b diff --git a/localtests/generated-columns-rename57/create.sql b/localtests/generated-columns-rename57/create.sql new file mode 100644 index 0000000..e244ca3 --- /dev/null +++ b/localtests/generated-columns-rename57/create.sql @@ -0,0 +1,30 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + a int not null, + b int not null, + sum_ab int as (a + b) virtual not null, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test (id, a, b) values (null, 2,3); + insert into gh_ost_test (id, a, b) values (null, 2,4); + insert into gh_ost_test (id, a, b) values (null, 2,5); + insert into gh_ost_test (id, a, b) values (null, 2,6); + insert into gh_ost_test (id, a, b) values (null, 2,7); + insert into gh_ost_test (id, a, b) values (null, 2,8); + insert into gh_ost_test (id, a, b) values (null, 2,9); + insert into gh_ost_test (id, a, b) values (null, 2,0); + insert into gh_ost_test (id, a, b) values (null, 2,1); + insert into gh_ost_test (id, a, b) values (null, 2,2); +end ;; diff --git a/localtests/generated-columns-rename57/extra_args b/localtests/generated-columns-rename57/extra_args new file mode 100644 index 0000000..6a19098 --- /dev/null +++ b/localtests/generated-columns-rename57/extra_args @@ -0,0 +1 @@ +--alter="change sum_ab total_ab int as (a + b) virtual not null" --approve-renamed-columns diff --git a/localtests/generated-columns-rename57/ignore_versions b/localtests/generated-columns-rename57/ignore_versions new file mode 100644 index 0000000..b6de5f8 --- /dev/null +++ b/localtests/generated-columns-rename57/ignore_versions @@ -0,0 +1 @@ +(5.5|5.6) diff --git a/localtests/generated-columns57/create.sql b/localtests/generated-columns57/create.sql new file mode 100644 index 0000000..e244ca3 --- /dev/null +++ b/localtests/generated-columns57/create.sql @@ -0,0 +1,30 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + a int not null, + b int not null, + sum_ab int as (a + b) virtual not null, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test (id, a, b) values (null, 2,3); + insert into gh_ost_test (id, a, b) values (null, 2,4); + insert into gh_ost_test (id, a, b) values (null, 2,5); + insert into gh_ost_test (id, a, b) values (null, 2,6); + insert into gh_ost_test (id, a, b) values (null, 2,7); + insert into gh_ost_test (id, a, b) values (null, 2,8); + insert into gh_ost_test (id, a, b) values (null, 2,9); + insert into gh_ost_test (id, a, b) values (null, 2,0); + insert into gh_ost_test (id, a, b) values (null, 2,1); + insert into gh_ost_test (id, a, b) values (null, 2,2); +end ;; diff --git a/localtests/generated-columns57/ignore_versions b/localtests/generated-columns57/ignore_versions new file mode 100644 index 0000000..b6de5f8 --- /dev/null +++ b/localtests/generated-columns57/ignore_versions @@ -0,0 +1 @@ +(5.5|5.6) diff --git a/localtests/geometry57/create.sql b/localtests/geometry57/create.sql new file mode 100644 index 0000000..6dd64c6 --- /dev/null +++ b/localtests/geometry57/create.sql @@ -0,0 +1,21 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + g geometry, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, ST_GeomFromText('POINT(1 1)')); + insert into gh_ost_test values (null, ST_GeomFromText('POINT(2 2)')); + insert into gh_ost_test values (null, ST_GeomFromText('POINT(3 3)')); +end ;; diff --git a/localtests/geometry57/ignore_versions b/localtests/geometry57/ignore_versions new file mode 100644 index 0000000..b6de5f8 --- /dev/null +++ b/localtests/geometry57/ignore_versions @@ -0,0 +1 @@ +(5.5|5.6) diff --git a/localtests/json57/ignore_versions b/localtests/json57/ignore_versions new file mode 100644 index 0000000..b6de5f8 --- /dev/null +++ b/localtests/json57/ignore_versions @@ -0,0 +1 @@ +(5.5|5.6) diff --git a/localtests/json57dml/create.sql b/localtests/json57dml/create.sql index da8cd57..0e76b2e 100644 --- a/localtests/json57dml/create.sql +++ b/localtests/json57dml/create.sql @@ -20,6 +20,7 @@ begin insert into gh_ost_test (id, i, j) values (null, 11, '"sometext"'); insert into gh_ost_test (id, i, j) values (null, 13, '{"key":"val"}'); insert into gh_ost_test (id, i, j) values (null, 17, '{"is-it": true, "count": 3, "elements": []}'); + insert into gh_ost_test (id, i, j) values (null, 19, '{"text":"Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo ligula eget dolor. Aenean massa. Cum sociis natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Donec quam felis, ultricies nec, pellentesque eu, pretium quis, sem. Nulla consequat massa quis enim. Donec pede justo, fringilla vel, aliquet nec, vulputate eget, arcu. In enim justo, rhoncus ut, imperdiet a, venenatis vitae, justo. Nullam dictum felis eu pede mollis pretium. Integer tincidunt. Cras dapibus. Vivamus elementum semper nisi. Aenean vulputate eleifend tellus. Aenean leo ligula, porttitor eu, consequat vitae, eleifend ac, enim. Aliquam lorem ante, dapibus in, viverra quis, feugiat a, tellus. Phasellus viverra nulla ut metus varius laoreet. Quisque rutrum. Aenean imperdiet. Etiam ultricies nisi vel augue. Curabitur ullamcorper ultricies nisi. Nam eget dui. Etiam rhoncus. Maecenas tempus, tellus eget condimentum rhoncus, sem quam semper libero, sit amet adipiscing sem neque sed ipsum. Nam quam nunc, blandit vel, luctus pulvinar, hendrerit id, lorem. Maecenas nec odio et ante tincidunt tempus. Donec vitae sapien ut libero venenatis faucibus. Nullam quis ante. Etiam sit amet orci eget eros faucibus tincidunt. Duis leo. Sed fringilla mauris sit amet nibh. Donec sodales sagittis magna. Sed consequat, leo eget bibendum sodales, augue velit cursus nunc, quis gravida magna mi a libero. Fusce vulputate eleifend sapien. Vestibulum purus quam, scelerisque ut, mollis sed, nonummy id, metus. Nullam accumsan lorem in dui. Cras ultricies mi eu turpis hendrerit fringilla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; In ac dui quis mi consectetuer lacinia. Nam pretium turpis et arcu. Duis arcu tortor, suscipit eget, imperdiet nec, imperdiet iaculis, ipsum. Sed aliquam ultrices mauris. Integer ante arcu, accumsan a, consectetuer eget, posuere ut, mauris. Praesent adipiscing. Phasellus ullamcorper ipsum rutrum nunc. Nunc nonummy metus. Vestibulum volutpat pretium libero. Cras id dui. Aenean ut eros et nisl sagittis vestibulum. Nullam nulla eros, ultricies sit amet, nonummy id, imperdiet feugiat, pede. Sed lectus. Donec mollis hendrerit risus. Phasellus nec sem in justo pellentesque facilisis. Etiam imperdiet imperdiet orci. Nunc nec neque. Phasellus leo dolor, tempus non, auctor et, hendrerit quis, nisi. Curabitur ligula sapien, tincidunt non, euismod vitae, posuere imperdiet, leo. Maecenas malesuada. Praesent congue erat at massa. Sed cursus turpis vitae tortor. Donec posuere vulputate arcu. Phasellus accumsan cursus velit. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Sed aliquam, nisi quis porttitor congue, elit erat euismod orci, ac"}'); update gh_ost_test set j = '{"updated": 11}', updated = 1 where i = 11 and updated = 0; update gh_ost_test set j = json_set(j, '$.count', 13, '$.id', id), updated = 1 where i = 13 and updated = 0; diff --git a/localtests/json57dml/ignore_versions b/localtests/json57dml/ignore_versions new file mode 100644 index 0000000..b6de5f8 --- /dev/null +++ b/localtests/json57dml/ignore_versions @@ -0,0 +1 @@ +(5.5|5.6) diff --git a/localtests/keyword-column/create.sql b/localtests/keyword-column/create.sql new file mode 100644 index 0000000..9371238 --- /dev/null +++ b/localtests/keyword-column/create.sql @@ -0,0 +1,13 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + color varchar(32), + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; + +insert into gh_ost_test values (null, 11, 'red'); +insert into gh_ost_test values (null, 13, 'green'); +insert into gh_ost_test values (null, 17, 'blue'); diff --git a/localtests/keyword-column/extra_args b/localtests/keyword-column/extra_args new file mode 100644 index 0000000..5d73843 --- /dev/null +++ b/localtests/keyword-column/extra_args @@ -0,0 +1 @@ +--alter='add column `index` int unsigned' \ diff --git a/localtests/keyword-column/ghost_columns b/localtests/keyword-column/ghost_columns new file mode 100644 index 0000000..f5941f3 --- /dev/null +++ b/localtests/keyword-column/ghost_columns @@ -0,0 +1 @@ +id, i, color diff --git a/localtests/keyword-column/orig_columns b/localtests/keyword-column/orig_columns new file mode 100644 index 0000000..f5941f3 --- /dev/null +++ b/localtests/keyword-column/orig_columns @@ -0,0 +1 @@ +id, i, color diff --git a/localtests/latin1text/create.sql b/localtests/latin1text/create.sql new file mode 100644 index 0000000..58837cb --- /dev/null +++ b/localtests/latin1text/create.sql @@ -0,0 +1,25 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + t text charset latin1 collate latin1_swedish_ci, + primary key(id) +) auto_increment=1 charset latin1 collate latin1_swedish_ci; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, md5(rand())); + insert into gh_ost_test values (null, 'átesting'); + insert into gh_ost_test values (null, 'ádelete'); + insert into gh_ost_test values (null, 'testátest'); + update gh_ost_test set t='áupdated' order by id desc limit 1; + update gh_ost_test set t='áupdated1' where t='áupdated' order by id desc limit 1; + delete from gh_ost_test where t='ádelete'; +end ;; diff --git a/localtests/spatial57/create.sql b/localtests/spatial57/create.sql new file mode 100644 index 0000000..bb213d5 --- /dev/null +++ b/localtests/spatial57/create.sql @@ -0,0 +1,22 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + g geometry, + pt point, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, ST_GeomFromText('POINT(1 1)'), POINT(10,10)); + insert into gh_ost_test values (null, ST_GeomFromText('POINT(2 2)'), POINT(20,20)); + insert into gh_ost_test values (null, ST_GeomFromText('POINT(3 3)'), POINT(30,30)); +end ;; diff --git a/localtests/spatial57/ignore_versions b/localtests/spatial57/ignore_versions new file mode 100644 index 0000000..b6de5f8 --- /dev/null +++ b/localtests/spatial57/ignore_versions @@ -0,0 +1 @@ +(5.5|5.6) diff --git a/localtests/swap-pk-uk/ignore_versions b/localtests/swap-pk-uk/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/swap-pk-uk/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/swap-uk-uk/create.sql b/localtests/swap-uk-uk/create.sql index 5fcbf32..30c1542 100644 --- a/localtests/swap-uk-uk/create.sql +++ b/localtests/swap-uk-uk/create.sql @@ -1,8 +1,8 @@ drop table if exists gh_ost_test; create table gh_ost_test ( - id bigint, + id bigint not null, i int not null, - ts timestamp(6), + ts timestamp(6) not null, unique key id_uidx(id), unique key its_uidx(i, ts) ) ; diff --git a/localtests/swap-uk-uk/ignore_versions b/localtests/swap-uk-uk/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/swap-uk-uk/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/test.sh b/localtests/test.sh index b901ce3..d4b3f17 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -9,16 +9,31 @@ tests_path=$(dirname $0) test_logfile=/tmp/gh-ost-test.log -ghost_binary=/tmp/gh-ost-test +default_ghost_binary=/tmp/gh-ost-test +ghost_binary="" exec_command_file=/tmp/gh-ost-test.bash -orig_content_output_file=/gh-ost-test.orig.content.csv -ghost_content_output_file=/gh-ost-test.ghost.content.csv -test_pattern="${1:-.}" +orig_content_output_file=/tmp/gh-ost-test.orig.content.csv +ghost_content_output_file=/tmp/gh-ost-test.ghost.content.csv +throttle_flag_file=/tmp/gh-ost-test.ghost.throttle.flag master_host= master_port= replica_host= replica_port= +original_sql_mode= + +OPTIND=1 +while getopts "b:" OPTION +do + case $OPTION in + b) + ghost_binary="$OPTARG" + ;; + esac +done +shift $((OPTIND-1)) + +test_pattern="${1:-.}" verify_master_and_replica() { if [ "$(gh-ost-test-mysql-master -e "select 1" -ss)" != "1" ] ; then @@ -26,6 +41,18 @@ verify_master_and_replica() { exit 1 fi read master_host master_port <<< $(gh-ost-test-mysql-master -e "select @@hostname, @@port" -ss) + [ "$master_host" == "$(hostname)" ] && master_host="127.0.0.1" + echo "# master verified at $master_host:$master_port" + if ! gh-ost-test-mysql-master -e "set global event_scheduler := 1" ; then + echo "Cannot enable event_scheduler on master" + exit 1 + fi + original_sql_mode="$(gh-ost-test-mysql-master -e "select @@global.sql_mode" -s -s)" + echo "sql_mode on master is ${original_sql_mode}" + + echo "Gracefully sleeping for 3 seconds while replica is setting up..." + sleep 3 + if [ "$(gh-ost-test-mysql-replica -e "select 1" -ss)" != "1" ] ; then echo "Cannot verify gh-ost-test-mysql-replica" exit 1 @@ -35,6 +62,8 @@ verify_master_and_replica() { exit 1 fi read replica_host replica_port <<< $(gh-ost-test-mysql-replica -e "select @@hostname, @@port" -ss) + [ "$replica_host" == "$(hostname)" ] && replica_host="127.0.0.1" + echo "# replica verified at $replica_host:$replica_port" } exec_cmd() { @@ -66,11 +95,26 @@ test_single() { local test_name test_name="$1" + if [ -f $tests_path/$test_name/ignore_versions ] ; then + ignore_versions=$(cat $tests_path/$test_name/ignore_versions) + mysql_version=$(gh-ost-test-mysql-master -s -s -e "select @@version") + if echo "$mysql_version" | egrep -q "^${ignore_versions}" ; then + echo -n "Skipping: $test_name" + return 0 + fi + fi + echo -n "Testing: $test_name" echo_dot start_replication echo_dot + + if [ -f $tests_path/$test_name/sql_mode ] ; then + gh-ost-test-mysql-master --default-character-set=utf8mb4 test -e "set @@global.sql_mode='$(cat $tests_path/$test_name/sql_mode)'" + gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "set @@global.sql_mode='$(cat $tests_path/$test_name/sql_mode)'" + fi + gh-ost-test-mysql-master --default-character-set=utf8mb4 test < $tests_path/$test_name/create.sql extra_args="" @@ -98,6 +142,7 @@ test_single() { --password=gh-ost \ --host=$replica_host \ --port=$replica_port \ + --assume-master-host=${master_host}:${master_port} --database=test \ --table=gh_ost_test \ --alter='engine=innodb' \ @@ -106,10 +151,11 @@ test_single() { --initially-drop-old-table \ --initially-drop-ghost-table \ --throttle-query='select timestampdiff(second, min(last_update), now()) < 5 from _gh_ost_test_ghc' \ + --throttle-flag-file=$throttle_flag_file \ --serve-socket-file=/tmp/gh-ost.test.sock \ --initially-drop-socket-file \ --test-on-replica \ - --default-retries=1 \ + --default-retries=3 \ --chunk-size=10 \ --verbose \ --debug \ @@ -122,6 +168,11 @@ test_single() { execution_result=$? + if [ -f $tests_path/$test_name/sql_mode ] ; then + gh-ost-test-mysql-master --default-character-set=utf8mb4 test -e "set @@global.sql_mode='${original_sql_mode}'" + gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "set @@global.sql_mode='${original_sql_mode}'" + fi + if [ -f $tests_path/$test_name/destroy.sql ] ; then gh-ost-test-mysql-master --default-character-set=utf8mb4 test < $tests_path/$test_name/destroy.sql fi @@ -148,7 +199,8 @@ test_single() { if [ $execution_result -ne 0 ] ; then echo - echo "ERROR $test_name execution failure. cat $test_logfile" + echo "ERROR $test_name execution failure. cat $test_logfile:" + cat $test_logfile return 1 fi @@ -164,13 +216,24 @@ test_single() { diff $orig_content_output_file $ghost_content_output_file echo "diff $orig_content_output_file $ghost_content_output_file" + return 1 fi } build_binary() { echo "Building" + rm -f $default_ghost_binary + [ "$ghost_binary" == "" ] && ghost_binary="$default_ghost_binary" + if [ -f "$ghost_binary" ] ; then + echo "Using binary: $ghost_binary" + return 0 + fi go build -o $ghost_binary go/cmd/gh-ost/main.go + if [ $? -ne 0 ] ; then + echo "Build failure" + exit 1 + fi } test_all() { diff --git a/localtests/timestamp-to-datetime/ignore_versions b/localtests/timestamp-to-datetime/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/timestamp-to-datetime/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/timestamp/ignore_versions b/localtests/timestamp/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/timestamp/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/tz-datetime-ts/ignore_versions b/localtests/tz-datetime-ts/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/tz-datetime-ts/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/tz/ignore_versions b/localtests/tz/ignore_versions new file mode 100644 index 0000000..7acd3f0 --- /dev/null +++ b/localtests/tz/ignore_versions @@ -0,0 +1 @@ +(5.5) diff --git a/localtests/varbinary/create.sql b/localtests/varbinary/create.sql new file mode 100644 index 0000000..9d85200 --- /dev/null +++ b/localtests/varbinary/create.sql @@ -0,0 +1,40 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id binary(16) NOT NULL, + info varchar(255) COLLATE utf8_unicode_ci NOT NULL, + data binary(8) NOT NULL, + primary key (id), + unique key info_uidx (info) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + replace into gh_ost_test (id, info, data) values (X'12ffffffffffffffffffffffffffff00', 'item 1a', X'12ffffffffffffff'); + replace into gh_ost_test (id, info, data) values (X'34ffffffffffffffffffffffffffffff', 'item 3a', X'34ffffffffffffff'); + replace into gh_ost_test (id, info, data) values (X'90ffffffffffffffffffffffffffffff', 'item 9a', X'90ffffffffffff00'); + + DELETE FROM gh_ost_test WHERE id = X'11ffffffffffffffffffffffffffff00'; + UPDATE gh_ost_test SET info = 'item 2++' WHERE id = X'22ffffffffffffffffffffffffffff00'; + UPDATE gh_ost_test SET info = 'item 3++', data = X'33ffffffffffff00' WHERE id = X'33ffffffffffffffffffffffffffffff'; + DELETE FROM gh_ost_test WHERE id = X'44ffffffffffffffffffffffffffffff'; + UPDATE gh_ost_test SET info = 'item 5++', data = X'55ffffffffffffee' WHERE id = X'55ffffffffffffffffffffffffffffff'; + INSERT INTO gh_ost_test (id, info, data) VALUES (X'66ffffffffffffffffffffffffffff00', 'item 6', X'66ffffffffffffff'); + INSERT INTO gh_ost_test (id, info, data) VALUES (X'77ffffffffffffffffffffffffffffff', 'item 7', X'77ffffffffffff00'); + INSERT INTO gh_ost_test (id, info, data) VALUES (X'88ffffffffffffffffffffffffffffff', 'item 8', X'88ffffffffffffff'); +end ;; + +INSERT INTO gh_ost_test (id, info, data) VALUES + (X'11ffffffffffffffffffffffffffff00', 'item 1', X'11ffffffffffffff'), -- id ends in 00 + (X'22ffffffffffffffffffffffffffff00', 'item 2', X'22ffffffffffffff'), -- id ends in 00 + (X'33ffffffffffffffffffffffffffffff', 'item 3', X'33ffffffffffffff'), + (X'44ffffffffffffffffffffffffffffff', 'item 4', X'44ffffffffffffff'), + (X'55ffffffffffffffffffffffffffffff', 'item 5', X'55ffffffffffffff'), + (X'99ffffffffffffffffffffffffffffff', 'item 9', X'99ffffffffffff00'); -- data ends in 00 diff --git a/script/bootstrap b/script/bootstrap index 6ac885b..573313a 100755 --- a/script/bootstrap +++ b/script/bootstrap @@ -4,6 +4,7 @@ set -e # Make sure we have the version of Go we want to depend on, either from the # system or one we grab ourselves. +# If executing from within Dockerfile then this assumption is inherently true, since we use a `golang` docker image. . script/ensure-go-installed # Since we want to be able to build this outside of GOPATH, we set it diff --git a/script/build-deploy-tarball b/script/build-deploy-tarball new file mode 100755 index 0000000..dc28b43 --- /dev/null +++ b/script/build-deploy-tarball @@ -0,0 +1,35 @@ +#!/bin/sh + +set -e + +script/build + +# Get a fresh directory and make sure to delete it afterwards +build_dir=tmp/build +rm -rf $build_dir +mkdir -p $build_dir +trap "rm -rf $build_dir" EXIT + +commit_sha=$(git rev-parse HEAD) + +if [ $(uname -s) = "Darwin" ]; then + build_arch="$(uname -sr | tr -d ' ' | tr '[:upper:]' '[:lower:]')-$(uname -m)" +else + build_arch="$(lsb_release -sc | tr -d ' ' | tr '[:upper:]' '[:lower:]')-$(uname -m)" +fi + +tarball=$build_dir/${commit_sha}-${build_arch}.tar + +# Create the tarball +tar cvf $tarball --mode="ugo=rx" bin/ + +# Compress it and copy it to the directory for the CI to upload it +gzip $tarball +mkdir -p "$BUILD_ARTIFACT_DIR"/gh-ost +cp ${tarball}.gz "$BUILD_ARTIFACT_DIR"/gh-ost/ + +### HACK HACK HACK HACK ### +# blame @carlosmn, @mattr and @timvaillancourt- +# Allow builds on buster to also be used for stretch +stretch_tarball_name=$(echo $(basename "${tarball}") | sed s/-buster-/-stretch-/) +cp ${tarball}.gz "$BUILD_ARTIFACT_DIR/gh-ost/${stretch_tarball_name}.gz" diff --git a/script/cibuild b/script/cibuild index 7e757b5..e609b7a 100755 --- a/script/cibuild +++ b/script/cibuild @@ -1,17 +1,3 @@ #!/bin/bash -set -e - -. script/bootstrap - -echo "Verifying code is formatted via 'gofmt -s -w go/'" -gofmt -s -w go/ -git diff --exit-code --quiet - -echo "Building" -script/build - -cd .gopath/src/github.com/github/gh-ost - -echo "Running unit tests" -go test ./go/... +script/test diff --git a/script/cibuild-gh-ost-build-deploy-tarball b/script/cibuild-gh-ost-build-deploy-tarball index 692b42b..a852ad6 100755 --- a/script/cibuild-gh-ost-build-deploy-tarball +++ b/script/cibuild-gh-ost-build-deploy-tarball @@ -1,37 +1,47 @@ -#!/bin/sh +#!/bin/bash -set -e +output_fold() { + # Exit early if no label provided + if [ -z "$1" ]; then + echo "output_fold(): requires a label argument." + return + fi -script/cibuild + exit_value=0 # exit_value is used to record exit status of the given command + label=$1 # human-readable label describing what's being folded up + shift 1 # having retrieved the output_fold()-specific arguments, strip them off $@ -# Get a fresh directory and make sure to delete it afterwards -build_dir=tmp/build -rm -rf $build_dir -mkdir -p $build_dir -trap "rm -rf $build_dir" EXIT + # Only echo the tags when in CI_MODE + if [ "$CI_MODE" ]; then + echo "%%%FOLD {$label}%%%" + fi -commit_sha=$(git rev-parse HEAD) + # run the remaining arguments. If the command exits non-0, the `||` will + # prevent the `-e` flag from seeing the failure exit code, and we'll see + # the second echo execute + "$@" || exit_value=$? -if [ $(uname -s) = "Darwin" ]; then - build_arch="$(uname -sr | tr -d ' ' | tr '[:upper:]' '[:lower:]')-$(uname -m)" -else - build_arch="$(lsb_release -sc | tr -d ' ' | tr '[:upper:]' '[:lower:]')-$(uname -m)" -fi + # Only echo the tags when in CI_MODE + if [ "$CI_MODE" ]; then + echo "%%%END FOLD%%%" + fi -tarball=$build_dir/${commit_sha}-${build_arch}.tar + # preserve the exit code from the subcommand. + return $exit_value +} -# Create the tarball -tar cvf $tarball --mode="ugo=rx" bin/ +function cleanup() { + echo + echo "%%%FOLD {Shutting down services...}%%%" + docker-compose down + echo "%%%END FOLD%%%" +} -# Compress it and copy it to the directory for the CI to upload it -gzip $tarball -mkdir -p "$BUILD_ARTIFACT_DIR"/gh-ost -cp ${tarball}.gz "$BUILD_ARTIFACT_DIR"/gh-ost/ +trap cleanup EXIT -### HACK HACK HACK ### -# Blame @carlosmn. In the good way. -# We don't have any jessie machines for building, but a pure-Go binary depends -# on a version of libc and ld which are widely available, so we can copy the -# tarball over with jessie in its name so we can deploy it on jessie machines. -jessie_tarball_name=$(echo $(basename "${tarball}") | sed s/-precise-/-jessie-/) -cp ${tarball}.gz "$BUILD_ARTIFACT_DIR/gh-ost/${jessie_tarball_name}.gz" +export CI_MODE=true + +output_fold "Bootstrapping container..." docker-compose build +output_fold "Running tests..." docker-compose run --rm app + +docker-compose run -e BUILD_ARTIFACT_DIR=$BUILD_ARTIFACT_DIR -v $BUILD_ARTIFACT_DIR:$BUILD_ARTIFACT_DIR app script/build-deploy-tarball diff --git a/script/cibuild-gh-ost-replica-tests b/script/cibuild-gh-ost-replica-tests new file mode 100755 index 0000000..3de9e05 --- /dev/null +++ b/script/cibuild-gh-ost-replica-tests @@ -0,0 +1,70 @@ +#!/bin/bash + +set -e + +whoami + +# Clone gh-ost-ci-env +# Only clone if not already running locally at latest commit +remote_commit=$(git ls-remote https://github.com/github/gh-ost-ci-env.git HEAD | cut -f1) +local_commit="unknown" +[ -d "gh-ost-ci-env" ] && local_commit=$(cd gh-ost-ci-env && git log --format="%H" -n 1) + +echo "remote commit is: $remote_commit" +echo "local commit is: $local_commit" + +if [ "$remote_commit" != "$local_commit" ] ; then + rm -rf ./gh-ost-ci-env + git clone https://github.com/github/gh-ost-ci-env.git +fi + +test_mysql_version() { + local mysql_version + mysql_version="$1" + + echo "##### Testing $mysql_version" + + echo "### Setting up sandbox for $mysql_version" + + find sandboxes -name "stop_all" | bash + + mkdir -p sandbox/binary + rm -rf sandbox/binary/* + gh-ost-ci-env/bin/linux/dbdeployer unpack gh-ost-ci-env/mysql-tarballs/"$mysql_version".tar.gz --unpack-version="$mysql_version" --sandbox-binary ${PWD}/sandbox/binary + + mkdir -p sandboxes + rm -rf sandboxes/* + + if echo "$mysql_version" | egrep "5[.]5[.]" ; then + gtid="" + else + gtid="--gtid" + fi + gh-ost-ci-env/bin/linux/dbdeployer deploy replication "$mysql_version" --nodes 2 --sandbox-binary ${PWD}/sandbox/binary --sandbox-home ${PWD}/sandboxes ${gtid} --my-cnf-options log_slave_updates --my-cnf-options log_bin --my-cnf-options binlog_format=ROW --sandbox-directory rsandbox + + sed '/sandboxes/d' -i gh-ost-ci-env/bin/gh-ost-test-mysql-master + echo 'sandboxes/rsandbox/m "$@"' >> gh-ost-ci-env/bin/gh-ost-test-mysql-master + + sed '/sandboxes/d' -i gh-ost-ci-env/bin/gh-ost-test-mysql-replica + echo 'sandboxes/rsandbox/s1 "$@"' >> gh-ost-ci-env/bin/gh-ost-test-mysql-replica + + export PATH="${PWD}/gh-ost-ci-env/bin/:${PATH}" + + gh-ost-test-mysql-master -uroot -e "create user 'gh-ost'@'%' identified by 'gh-ost'" + gh-ost-test-mysql-master -uroot -e "grant all on *.* to 'gh-ost'@'%'" + + echo "### Running gh-ost tests for $mysql_version" + ./localtests/test.sh -b bin/gh-ost + + find sandboxes -name "stop_all" | bash +} + +echo "Building..." +. script/build +# Test all versions: +find gh-ost-ci-env/mysql-tarballs/ -name "*.tar.gz" | while read f ; do basename $f ".tar.gz" ; done | sort -r | while read mysql_version ; do + echo "found MySQL version: $mysql_version" +done +find gh-ost-ci-env/mysql-tarballs/ -name "*.tar.gz" | while read f ; do basename $f ".tar.gz" ; done | sort -r | while read mysql_version ; do + test_mysql_version "$mysql_version" +done diff --git a/script/dock b/script/dock new file mode 100755 index 0000000..486061d --- /dev/null +++ b/script/dock @@ -0,0 +1,25 @@ +#!/bin/bash + +# Usage: +# dock [arg] +# dock test: build gh-ost & run unit and integration tests +# docker pkg [target-path]: build gh-ost release packages and copy to target path (default path: /tmp/gh-ost-release) + +command="$1" + +case "$command" in + "test") + docker_target="gh-ost-test" + docker build . -f Dockerfile.test -t "${docker_target}" && docker run --rm -it "${docker_target}:latest" + ;; + "pkg") + packages_path="${2:-/tmp/gh-ost-release}" + docker_target="gh-ost-packaging" + docker build . -f Dockerfile.packaging -t "${docker_target}" && docker run --rm -it -v "${packages_path}:/tmp/pkg" "${docker_target}:latest" bash -c 'find /tmp/gh-ost-release/ -maxdepth 1 -type f | xargs cp -t /tmp/pkg' + echo "packages generated on ${packages_path}:" + ls -l "${packages_path}" + ;; + *) + >&2 echo "Usage: dock dock [arg]" + exit 1 +esac diff --git a/script/ensure-go-installed b/script/ensure-go-installed index 21c49e6..baa5bd7 100755 --- a/script/ensure-go-installed +++ b/script/ensure-go-installed @@ -1,19 +1,20 @@ #!/bin/bash -GO_VERSION=go1.7 +PREFERRED_GO_VERSION=go1.14.7 +SUPPORTED_GO_VERSIONS='go1.1[456]' -GO_PKG_DARWIN=${GO_VERSION}.darwin-amd64.pkg -GO_PKG_DARWIN_SHA=e7089843bc7148ffcc147759985b213604d22bb9fd19bd930b515aa981bf1b22 +GO_PKG_DARWIN=${PREFERRED_GO_VERSION}.darwin-amd64.pkg +GO_PKG_DARWIN_SHA=0f215de06019a054a3da46a0722989986c956d719c7a0a8fc38a5f3c216d6f6b -GO_PKG_LINUX=${GO_VERSION}.linux-amd64.tar.gz -GO_PKG_LINUX_SHA=702ad90f705365227e902b42d91dd1a40e48ca7f67a2f4b2fd052aaa4295cd95 +GO_PKG_LINUX=${PREFERRED_GO_VERSION}.linux-amd64.tar.gz +GO_PKG_LINUX_SHA=4a7fa60f323ee1416a4b1425aefc37ea359e9d64df19c326a58953a97ad41ea5 export ROOTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" cd $ROOTDIR # If Go isn't installed globally, setup environment variables for local install. -if [ -z "$(which go)" ] || [ -z "$(go version | grep $GO_VERSION)" ]; then - GODIR="$ROOTDIR/.vendor/go17" +if [ -z "$(which go)" ] || [ -z "$(go version | grep "$SUPPORTED_GO_VERSIONS")" ]; then + GODIR="$ROOTDIR/.vendor/golocal" if [ $(uname -s) = "Darwin" ]; then export GOROOT="$GODIR/usr/local/go" @@ -25,24 +26,26 @@ if [ -z "$(which go)" ] || [ -z "$(go version | grep $GO_VERSION)" ]; then fi # Check if local install exists, and install otherwise. -if [ -z "$(which go)" ] || [ -z "$(go version | grep $GO_VERSION)" ]; then +if [ -z "$(which go)" ] || [ -z "$(go version | grep "$SUPPORTED_GO_VERSIONS")" ]; then [ -d "$GODIR" ] && rm -rf $GODIR mkdir -p "$GODIR" cd "$GODIR"; if [ $(uname -s) = "Darwin" ]; then - curl -L -O https://storage.googleapis.com/golang/$GO_PKG_DARWIN + curl -L -O https://dl.google.com/go/$GO_PKG_DARWIN shasum -a256 $GO_PKG_DARWIN | grep $GO_PKG_DARWIN_SHA xar -xf $GO_PKG_DARWIN cpio -i < com.googlecode.go.pkg/Payload else - curl -L -O https://storage.googleapis.com/golang/$GO_PKG_LINUX + curl -L -O https://dl.google.com/go/$GO_PKG_LINUX shasum -a256 $GO_PKG_LINUX | grep $GO_PKG_LINUX_SHA tar xf $GO_PKG_LINUX fi # Prove we did something right - echo "$GO_VERSION installed in $GODIR: Go Binary: $(which go)" + echo "$(go version) installed in $GODIR: Go Binary: $(which go)" +else + echo "$(go version) found in $GODIR: Go Binary: $(which go)" fi cd $ROOTDIR diff --git a/script/test b/script/test new file mode 100755 index 0000000..7e757b5 --- /dev/null +++ b/script/test @@ -0,0 +1,17 @@ +#!/bin/bash + +set -e + +. script/bootstrap + +echo "Verifying code is formatted via 'gofmt -s -w go/'" +gofmt -s -w go/ +git diff --exit-code --quiet + +echo "Building" +script/build + +cd .gopath/src/github.com/github/gh-ost + +echo "Running unit tests" +go test ./go/... diff --git a/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md b/vendor/github.com/go-sql-driver/mysql/.github/CONTRIBUTING.md similarity index 67% rename from vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md rename to vendor/github.com/go-sql-driver/mysql/.github/CONTRIBUTING.md index f87c198..8fe16bc 100644 --- a/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md +++ b/vendor/github.com/go-sql-driver/mysql/.github/CONTRIBUTING.md @@ -4,28 +4,11 @@ Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed). -Please provide the following minimum information: -* Your Go-MySQL-Driver version (or git SHA) -* Your Go version (run `go version` in your console) -* A detailed issue description -* Error Log if present -* If possible, a short example - - ## Contributing Code By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file. Don't forget to add yourself to the AUTHORS file. -### Pull Requests Checklist - -Please check the following points before submitting your pull request: -- [x] Code compiles correctly -- [x] Created tests, if possible -- [x] All tests pass -- [x] Extended the README / documentation, if necessary -- [x] Added yourself to the AUTHORS file - ### Code Review Everyone is invited to review and comment on pull requests. diff --git a/vendor/github.com/go-sql-driver/mysql/.github/ISSUE_TEMPLATE.md b/vendor/github.com/go-sql-driver/mysql/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..d9771f1 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,21 @@ +### Issue description +Tell us what should happen and what happens instead + +### Example code +```go +If possible, please enter some example code here to reproduce the issue. +``` + +### Error log +``` +If you have an error log, please paste it here. +``` + +### Configuration +*Driver version (or git SHA):* + +*Go version:* run `go version` in your console + +*Server version:* E.g. MySQL 5.6, MariaDB 10.0.20 + +*Server OS:* E.g. Debian 8.1 (Jessie), Windows 10 diff --git a/vendor/github.com/go-sql-driver/mysql/.github/PULL_REQUEST_TEMPLATE.md b/vendor/github.com/go-sql-driver/mysql/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..6f5c7eb --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,9 @@ +### Description +Please explain the changes you made here. + +### Checklist +- [ ] Code compiles correctly +- [ ] Created tests which fail without the change (if possible) +- [ ] All tests passing +- [ ] Extended the README / documentation, if necessary +- [ ] Added myself / the copyright holder to the AUTHORS file diff --git a/vendor/github.com/go-sql-driver/mysql/.gitignore b/vendor/github.com/go-sql-driver/mysql/.gitignore index ba8e0cb..2de28da 100644 --- a/vendor/github.com/go-sql-driver/mysql/.gitignore +++ b/vendor/github.com/go-sql-driver/mysql/.gitignore @@ -6,3 +6,4 @@ Icon? ehthumbs.db Thumbs.db +.idea diff --git a/vendor/github.com/go-sql-driver/mysql/.travis.yml b/vendor/github.com/go-sql-driver/mysql/.travis.yml index 2f4e3c2..56fcf25 100644 --- a/vendor/github.com/go-sql-driver/mysql/.travis.yml +++ b/vendor/github.com/go-sql-driver/mysql/.travis.yml @@ -1,10 +1,129 @@ sudo: false language: go go: - - 1.2 - - 1.3 - - 1.4 - - tip + - 1.10.x + - 1.11.x + - 1.12.x + - 1.13.x + - master + +before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls before_script: + - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf + - sudo service mysql restart + - .travis/wait_mysql.sh - mysql -e 'create database gotest;' + +matrix: + include: + - env: DB=MYSQL8 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mysql:8.0 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MYSQL57 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mysql:5.7 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MARIA55 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mariadb:5.5 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MARIA10_1 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mariadb:10.1 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - os: osx + osx_image: xcode10.1 + addons: + homebrew: + packages: + - mysql + update: true + go: 1.12.x + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + before_script: + - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB\nlocal_infile=1" >> /usr/local/etc/my.cnf + - mysql.server start + - mysql -uroot -e 'CREATE USER gotest IDENTIFIED BY "secret"' + - mysql -uroot -e 'GRANT ALL ON *.* TO gotest' + - mysql -uroot -e 'create database gotest;' + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3306 + - export MYSQL_TEST_CONCURRENT=1 + +script: + - go test -v -covermode=count -coverprofile=coverage.out + - go vet ./... + - .travis/gofmt.sh +after_script: + - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/vendor/github.com/go-sql-driver/mysql/.travis/docker.cnf b/vendor/github.com/go-sql-driver/mysql/.travis/docker.cnf new file mode 100644 index 0000000..e57754e --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/.travis/docker.cnf @@ -0,0 +1,5 @@ +[client] +user = gotest +password = secret +host = 127.0.0.1 +port = 3307 diff --git a/vendor/github.com/go-sql-driver/mysql/.travis/gofmt.sh b/vendor/github.com/go-sql-driver/mysql/.travis/gofmt.sh new file mode 100755 index 0000000..9bf0d16 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/.travis/gofmt.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -ev + +# Only check for go1.10+ since the gofmt style changed +if [[ $(go version) =~ go1\.([0-9]+) ]] && ((${BASH_REMATCH[1]} >= 10)); then + test -z "$(gofmt -d -s . | tee /dev/stderr)" +fi diff --git a/vendor/github.com/go-sql-driver/mysql/.travis/wait_mysql.sh b/vendor/github.com/go-sql-driver/mysql/.travis/wait_mysql.sh new file mode 100755 index 0000000..e87993e --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/.travis/wait_mysql.sh @@ -0,0 +1,8 @@ +#!/bin/sh +while : +do + if mysql -e 'select version()' 2>&1 | grep 'version()\|ERROR 2059 (HY000):'; then + break + fi + sleep 3 +done diff --git a/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/go-sql-driver/mysql/AUTHORS index 6dd0167..0896ba1 100644 --- a/vendor/github.com/go-sql-driver/mysql/AUTHORS +++ b/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -12,35 +12,95 @@ # Individual Persons Aaron Hopkins +Achille Roussel +Alex Snast +Alexey Palazhchenko +Andrew Reid Arne Hormann +Asta Xie +Bulat Gaifullin Carlos Nieto Chris Moos +Craig Wilson +Daniel Montoya +Daniel Nichter +Daniël van Eeden +Dave Protasowski DisposaBoy +Egor Smolyakov +Erwan Martin +Evan Shaw Frederick Mayle Gustavo Kristic +Hajime Nakagami Hanno Braun Henri Yandell Hirotaka Yamamoto +Huyiguang +ICHINOSE Shogo +Ilia Cimpoes INADA Naoki +Jacek Szwec James Harr +Jeff Hodges +Jeffrey Charles +Jerome Meyer +Jiajia Zhong Jian Zhen Joshua Prunier +Julien Lefevre Julien Schmidt +Justin Li +Justin Nuß Kamil Dziedzic +Kevin Malachowski +Kieron Woodhouse +Lennart Rudolph Leonardo YongUk Kim +Linh Tran Tuan +Lion Yang +Luca Looz Lucas Liu Luke Scott +Maciej Zimnoch Michael Woolnough +Nathanial Murphy Nicola Peduzzi +Olivier Mengué +oscarzhao +Paul Bonser +Peter Schultz +Rebecca Chin +Reed Allman +Richard Wilkes +Robert Russell Runrioter Wung +Shuode Li +Simon J Mudd Soroush Pour Stan Putrya +Stanley Gunawan +Steven Hartland +Thomas Wodarek +Tim Ruffles +Tom Jenkinson +Vladimir Kovpak +Xiangyu Hu Xiaobing Jiang Xiuming Chen -Julien Lefevre +Zhenye Xie # Organizations Barracuda Networks, Inc. +Counting Ltd. +DigitalOcean Inc. +Facebook Inc. +GitHub Inc. Google Inc. +InfoSum Ltd. +Keybase Inc. +Multiplay Ltd. +Percona LLC +Pivotal Inc. Stripe Inc. diff --git a/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md b/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md index 161ad0f..9cb97b3 100644 --- a/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md +++ b/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md @@ -1,21 +1,135 @@ -## HEAD +## Version 1.5 (2020-01-07) + +Changes: + + - Dropped support Go 1.9 and lower (#823, #829, #886, #1016, #1017) + - Improve buffer handling (#890) + - Document potentially insecure TLS configs (#901) + - Use a double-buffering scheme to prevent data races (#943) + - Pass uint64 values without converting them to string (#838, #955) + - Update collations and make utf8mb4 default (#877, #1054) + - Make NullTime compatible with sql.NullTime in Go 1.13+ (#995) + - Removed CloudSQL support (#993, #1007) + - Add Go Module support (#1003) + +New Features: + + - Implement support of optional TLS (#900) + - Check connection liveness (#934, #964, #997, #1048, #1051, #1052) + - Implement Connector Interface (#941, #958, #1020, #1035) + +Bugfixes: + + - Mark connections as bad on error during ping (#875) + - Mark connections as bad on error during dial (#867) + - Fix connection leak caused by rapid context cancellation (#1024) + - Mark connections as bad on error during Conn.Prepare (#1030) + + +## Version 1.4.1 (2018-11-14) + +Bugfixes: + + - Fix TIME format for binary columns (#818) + - Fix handling of empty auth plugin names (#835) + - Fix caching_sha2_password with empty password (#826) + - Fix canceled context broke mysqlConn (#862) + - Fix OldAuthSwitchRequest support (#870) + - Fix Auth Response packet for cleartext password (#887) + +## Version 1.4 (2018-06-03) + +Changes: + + - Documentation fixes (#530, #535, #567) + - Refactoring (#575, #579, #580, #581, #603, #615, #704) + - Cache column names (#444) + - Sort the DSN parameters in DSNs generated from a config (#637) + - Allow native password authentication by default (#644) + - Use the default port if it is missing in the DSN (#668) + - Removed the `strict` mode (#676) + - Do not query `max_allowed_packet` by default (#680) + - Dropped support Go 1.6 and lower (#696) + - Updated `ConvertValue()` to match the database/sql/driver implementation (#760) + - Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783) + - Improved the compatibility of the authentication system (#807) + +New Features: + + - Multi-Results support (#537) + - `rejectReadOnly` DSN option (#604) + - `context.Context` support (#608, #612, #627, #761) + - Transaction isolation level support (#619, #744) + - Read-Only transactions support (#618, #634) + - `NewConfig` function which initializes a config with default values (#679) + - Implemented the `ColumnType` interfaces (#667, #724) + - Support for custom string types in `ConvertValue` (#623) + - Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710) + - `caching_sha2_password` authentication plugin support (#794, #800, #801, #802) + - Implemented `driver.SessionResetter` (#779) + - `sha256_password` authentication plugin support (#808) + +Bugfixes: + + - Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718) + - Fixed LOAD LOCAL DATA INFILE for empty files (#590) + - Removed columns definition cache since it sometimes cached invalid data (#592) + - Don't mutate registered TLS configs (#600) + - Make RegisterTLSConfig concurrency-safe (#613) + - Handle missing auth data in the handshake packet correctly (#646) + - Do not retry queries when data was written to avoid data corruption (#302, #736) + - Cache the connection pointer for error handling before invalidating it (#678) + - Fixed imports for appengine/cloudsql (#700) + - Fix sending STMT_LONG_DATA for 0 byte data (#734) + - Set correct capacity for []bytes read from length-encoded strings (#766) + - Make RegisterDial concurrency-safe (#773) + + +## Version 1.3 (2016-12-01) Changes: - Go 1.1 is no longer supported - - Use decimals field from MySQL to format time types (#249) + - Use decimals fields in MySQL to format time types (#249) - Buffer optimizations (#269) - TLS ServerName defaults to the host (#283) + - Refactoring (#400, #410, #437) + - Adjusted documentation for second generation CloudSQL (#485) + - Documented DSN system var quoting rules (#502) + - Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512) + +New Features: + + - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) + - Support for returning table alias on Columns() (#289, #359, #382) + - Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490) + - Support for uint64 parameters with high bit set (#332, #345) + - Cleartext authentication plugin support (#327) + - Exported ParseDSN function and the Config struct (#403, #419, #429) + - Read / Write timeouts (#401) + - Support for JSON field type (#414) + - Support for multi-statements and multi-results (#411, #431) + - DSN parameter to set the driver-side max_allowed_packet value manually (#489) + - Native password authentication plugin support (#494, #524) Bugfixes: - - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) - Fixed handling of queries without columns and rows (#255) - Fixed a panic when SetKeepAlive() failed (#298) - -New Features: - - Support for returning table alias on Columns() (#289) - - Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318) + - Handle ERR packets while reading rows (#321) + - Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349) + - Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356) + - Actually zero out bytes in handshake response (#378) + - Fixed race condition in registering LOAD DATA INFILE handler (#383) + - Fixed tests with MySQL 5.7.9+ (#380) + - QueryUnescape TLS config names (#397) + - Fixed "broken pipe" error by writing to closed socket (#390) + - Fixed LOAD LOCAL DATA INFILE buffering (#424) + - Fixed parsing of floats into float64 when placeholders are used (#434) + - Fixed DSN tests with Go 1.7+ (#459) + - Handle ERR packets while waiting for EOF (#473) + - Invalidate connection on error while discarding additional results (#513) + - Allow terminating packets of length 0 (#516) ## Version 1.2 (2014-06-03) diff --git a/vendor/github.com/go-sql-driver/mysql/README.md b/vendor/github.com/go-sql-driver/mysql/README.md index 706b7ef..d2627a4 100644 --- a/vendor/github.com/go-sql-driver/mysql/README.md +++ b/vendor/github.com/go-sql-driver/mysql/README.md @@ -1,13 +1,9 @@ # Go-MySQL-Driver -A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) package +A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package ![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin") -**Latest stable Release:** [Version 1.2 (June 03, 2014)](https://github.com/go-sql-driver/mysql/releases) - -[![Build Status](https://travis-ci.org/go-sql-driver/mysql.png?branch=master)](https://travis-ci.org/go-sql-driver/mysql) - --------------------------------------- * [Features](#features) * [Requirements](#requirements) @@ -19,6 +15,9 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa * [Address](#address) * [Parameters](#parameters) * [Examples](#examples) + * [Connection pool and timeouts](#connection-pool-and-timeouts) + * [context.Context Support](#contextcontext-support) + * [ColumnType Support](#columntype-support) * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) * [time.Time support](#timetime-support) * [Unicode support](#unicode-support) @@ -30,31 +29,31 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa ## Features * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance") * Native Go implementation. No C-bindings, just pure Go - * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](http://godoc.org/github.com/go-sql-driver/mysql#DialFunc) + * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc) * Automatic handling of broken connections * Automatic Connection Pooling *(by database/sql package)* * Supports queries larger than 16MB - * Full [`sql.RawBytes`](http://golang.org/pkg/database/sql/#RawBytes) support. + * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. * Intelligent `LONG DATA` handling in prepared statements * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation ## Requirements - * Go 1.2 or higher + * Go 1.10 or higher. We aim to support the 3 latest versions of Go. * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- ## Installation -Simple install the package to your [$GOPATH](http://code.google.com/p/go-wiki/wiki/GOPATH "GOPATH") with the [go tool](http://golang.org/cmd/go/ "go command") from shell: +Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: ```bash -$ go get github.com/go-sql-driver/mysql +$ go get -u github.com/go-sql-driver/mysql ``` -Make sure [Git is installed](http://git-scm.com/downloads) on your machine and in your system's `PATH`. +Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. ## Usage -_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](http://golang.org/pkg/database/sql) API then. +_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: ```go @@ -93,17 +92,20 @@ This has the same effect as an empty DSN string: ``` +Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. + #### Password Passwords can consist of any character. Escaping is **not** necessary. #### Protocol -See [net.Dial](http://golang.org/pkg/net/#Dial) for more information which networks are available. +See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. In general you should use an Unix domain socket if available and TCP otherwise for best performance. #### Address -For TCP and UDP networks, addresses have the form `host:port`. +For TCP and UDP networks, addresses have the form `host[:port]`. +If `port` is omitted, the default port will be used. If `host` is a literal IPv6 address, it must be enclosed in square brackets. -The functions [net.JoinHostPort](http://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](http://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form. +The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](https://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form. For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`. @@ -133,6 +135,15 @@ Default: false `allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. +##### `allowNativePasswords` + +``` +Type: bool +Valid Values: true, false +Default: true +``` +`allowNativePasswords=false` disallows the usage of MySQL native password method. + ##### `allowOldPasswords` ``` @@ -155,18 +166,34 @@ Sets the charset used for client-server interaction (`"SET NAMES "`). If Usage of the `charset` parameter is discouraged because it issues additional queries to the server. Unless you need the fallback behavior, please use `collation` instead. +##### `checkConnLiveness` + +``` +Type: bool +Valid Values: true, false +Default: true +``` + +On supported platforms connections retrieved from the connection pool are checked for liveness before using them. If the check fails, the respective connection is marked as bad and the query retried with another connection. +`checkConnLiveness=false` disables this liveness check of connections. + ##### `collation` ``` Type: string Valid Values: -Default: utf8_general_ci +Default: utf8mb4_general_ci ``` Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail. A list of valid charsets for a server is retrievable with `SHOW COLLATION`. +The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. + +Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). + + ##### `clientFoundRows` ``` @@ -213,12 +240,31 @@ Valid Values: Default: UTC ``` -Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. +Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter. -Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. +Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. +##### `maxAllowedPacket` +``` +Type: decimal number +Default: 4194304 +``` + +Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. + +##### `multiStatements` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded. + +When `multiStatements` is used, `?` parameters must only be used in the first statement. ##### `parseTime` @@ -229,9 +275,19 @@ Default: false ``` `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` +The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`. -##### `strict` +##### `readTimeout` + +``` +Type: duration +Default: 0 +``` + +I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + +##### `rejectReadOnly` ``` Type: bool @@ -239,41 +295,89 @@ Valid Values: true, false Default: false ``` -`strict=true` enables the strict mode in which MySQL warnings are treated as errors. -By default MySQL also treats notes as warnings. Use [`sql_notes=false`](http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_sql_notes) to ignore notes. See the [examples](#examples) for an DSN example. +`rejectReadOnly=true` causes the driver to reject read-only connections. This +is for a possible race condition during an automatic failover, where the mysql +client gets connected to a read-only replica after the failover. + +Note that this should be a fairly rare case, as an automatic failover normally +happens when the primary is down, and the race condition shouldn't happen +unless it comes back up online as soon as the failover is kicked off. On the +other hand, when this happens, a MySQL application can get stuck on a +read-only connection until restarted. It is however fairly easy to reproduce, +for example, using a manual failover on AWS Aurora's MySQL-compatible cluster. + +If you are not relying on read-only transactions to reject writes that aren't +supposed to happen, setting this on some MySQL providers (such as AWS Aurora) +is safer for failovers. + +Note that ERROR 1290 can be returned for a `read-only` server and this option will +cause a retry for that error. However the same error number is used for some +other cases. You should ensure your application will never cause an ERROR 1290 +except for `read-only` mode when enabling this option. + + +##### `serverPubKey` + +``` +Type: string +Valid Values: +Default: none +``` + +Server public keys can be registered with [`mysql.RegisterServerPubKey`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterServerPubKey), which can then be used by the assigned name in the DSN. +Public keys are used to transmit encrypted data, e.g. for authentication. +If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required. ##### `timeout` ``` -Type: decimal number +Type: duration Default: OS default ``` -*Driver* side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout). +Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. ##### `tls` ``` Type: bool / string -Valid Values: true, false, skip-verify, +Valid Values: true, false, skip-verify, preferred, Default: false ``` -`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](http://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). +`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). + + +##### `writeTimeout` + +``` +Type: duration +Default: 0 +``` + +I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. ##### System Variables -All other parameters are interpreted as system variables: - * `autocommit`: `"SET autocommit="` - * [`time_zone`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `"SET time_zone="` - * [`tx_isolation`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `"SET tx_isolation="` - * `param`: `"SET ="` +Any other parameters are interpreted as system variables: + * `=`: `SET =` + * `=`: `SET =` + * `=%27%27`: `SET =''` + +Rules: +* The values for string variables must be quoted with `'`. +* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! + (which implies values of string variables must be wrapped with `%27`). + +Examples: + * `autocommit=1`: `SET autocommit=1` + * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` + * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` -*The values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!* #### Examples ``` @@ -288,9 +392,9 @@ root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true ``` -Use the [strict mode](#strict) but ignore notes: +Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html): ``` -user:password@/dbname?strict=true&sql_notes=false +user:password@/dbname?sql_mode=TRADITIONAL ``` TCP via IPv6: @@ -305,7 +409,7 @@ id:password@tcp(your-amazonaws-uri.com:3306)/dbname Google Cloud SQL on App Engine: ``` -user@cloudsql(project-id:instance-name)/dbname +user:password@unix(/cloudsql/project-id:region-name:instance-name)/dbname ``` TCP using default port (3306) on localhost: @@ -323,6 +427,18 @@ No Database preselected: user:password@/ ``` + +### Connection pool and timeouts +The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. + +## `ColumnType` Support +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. + +## `context.Context` Support +Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. +See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. + + ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): ```go @@ -333,28 +449,27 @@ Files must be whitelisted by registering them with `mysql.RegisterLocalFile(file To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. -See the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. +See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. ### `time.Time` support -The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm. +The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. -However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter. +However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). -Alternatively you can use the [`NullTime`](http://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. +Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. ### Unicode support -Since version 1.1 Go-MySQL-Driver automatically uses the collation `utf8_general_ci` by default. +Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default. -See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. - +See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. @@ -374,13 +489,13 @@ Mozilla summarizes the license scope as follows: That means: - * You can **use** the **unchanged** source code both in private and commercially - * When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0) - * You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged** + * You can **use** the **unchanged** source code both in private and commercially. + * When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). + * You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. -Please read the [MPL 2.0 FAQ](http://www.mozilla.org/MPL/2.0/FAQ.html) if you have further questions regarding the license. +Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license. -You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE) +You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). ![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") diff --git a/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/go-sql-driver/mysql/auth.go new file mode 100644 index 0000000..fec7040 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/auth.go @@ -0,0 +1,422 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "sync" +) + +// server pub keys registry +var ( + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey +) + +// RegisterServerPubKey registers a server RSA public key which can be used to +// send data in a secure manner to the server without receiving the public key +// in a potentially insecure way from the server first. +// Registered keys can afterwards be used adding serverPubKey= to the DSN. +// +// Note: The provided rsa.PublicKey instance is exclusively owned by the driver +// after registering it and may not be modified. +// +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } +// +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } +// +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } +// +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } +// +func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry == nil { + serverPubKeyRegistry = make(map[string]*rsa.PublicKey) + } + + serverPubKeyRegistry[name] = pubKey + serverPubKeyLock.Unlock() +} + +// DeregisterServerPubKey removes the public key registered with the given name. +func DeregisterServerPubKey(name string) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry != nil { + delete(serverPubKeyRegistry, name) + } + serverPubKeyLock.Unlock() +} + +func getServerPubKey(name string) (pubKey *rsa.PublicKey) { + serverPubKeyLock.RLock() + if v, ok := serverPubKeyRegistry[name]; ok { + pubKey = v + } + serverPubKeyLock.RUnlock() + return +} + +// Hash password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +// Pseudo random number generator +func newMyRnd(seed1, seed2 uint32) *myRnd { + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } +} + +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) +} + +// Generate binary hash from byte string using insecure pre 4.1 method +func pwHash(password []byte) (result [2]uint32) { + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 + + for _, c := range password { + // skip spaces and tabs in password + if c == ' ' || c == '\t' { + continue + } + + tmp = uint32(c) + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] + add += tmp + } + + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + + return +} + +// Hash password using insecure pre 4.1 method +func scrambleOldPassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + scramble = scramble[:8] + + hashPw := pwHash([]byte(password)) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i := range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} + +// Hash password using 4.1+ method (SHA1) +func scramblePassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write([]byte(password)) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} + +func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { + enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) + if err != nil { + return err + } + return mc.writeAuthSwitchPacket(enc) +} + +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { + switch plugin { + case "caching_sha2_password": + authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + return authResp, nil + + case "mysql_old_password": + if !mc.cfg.AllowOldPasswords { + return nil, ErrOldPassword + } + // Note: there are edge cases where this should work but doesn't; + // this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil + + case "mysql_clear_password": + if !mc.cfg.AllowCleartextPasswords { + return nil, ErrCleartextPassword + } + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + return append([]byte(mc.cfg.Passwd), 0), nil + + case "mysql_native_password": + if !mc.cfg.AllowNativePasswords { + return nil, ErrNativePassword + } + // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // Native password authentication only need and will need 20-byte challenge. + authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + return authResp, nil + + case "sha256_password": + if len(mc.cfg.Passwd) == 0 { + return []byte{0}, nil + } + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + return append([]byte(mc.cfg.Passwd), 0), nil + } + + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + return []byte{1}, nil + } + + // encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + return enc, err + + default: + errLog.Print("unknown auth plugin:", plugin) + return nil, ErrUnknownPlugin + } +} + +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { + // Read Result Packet + authData, newPlugin, err := mc.readAuthResult() + if err != nil { + return err + } + + // handle auth plugin switch, if requested + if newPlugin != "" { + // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is + // sent and we have to keep using the cipher sent in the init packet. + if authData == nil { + authData = oldAuthData + } else { + // copy data from read buffer to owned slice + copy(oldAuthData, authData) + } + + plugin = newPlugin + + authResp, err := mc.auth(authData, plugin) + if err != nil { + return err + } + if err = mc.writeAuthSwitchPacket(authResp); err != nil { + return err + } + + // Read Result Packet + authData, newPlugin, err = mc.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if newPlugin != "" { + return ErrMalformPkt + } + } + + switch plugin { + + // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + case "caching_sha2_password": + switch len(authData) { + case 0: + return nil // auth successful + case 1: + switch authData[0] { + case cachingSha2PasswordFastAuthSuccess: + if err = mc.readResultOK(); err == nil { + return nil // auth successful + } + + case cachingSha2PasswordPerformFullAuthentication: + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + if err != nil { + return err + } + } else { + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } + data[4] = cachingSha2PasswordRequestPublicKey + mc.writePacket(data) + + // parse public key + if data, err = mc.readPacket(); err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + pubKey = pkix.(*rsa.PublicKey) + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pubKey) + if err != nil { + return err + } + } + return mc.readResultOK() + + default: + return ErrMalformPkt + } + default: + return ErrMalformPkt + } + + case "sha256_password": + switch len(authData) { + case 0: + return nil // auth successful + default: + block, _ := pem.Decode(authData) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + return mc.readResultOK() + } + + default: + return nil // auth successful + } + + return err +} diff --git a/vendor/github.com/go-sql-driver/mysql/auth_test.go b/vendor/github.com/go-sql-driver/mysql/auth_test.go new file mode 100644 index 0000000..1920ef3 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/auth_test.go @@ -0,0 +1,1330 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "testing" +) + +var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" + + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAol0Z8G8U+25Btxk/g/fm\n" + + "UAW/wEKjQCTjkibDE4B+qkuWeiumg6miIRhtilU6m9BFmLQSy1ltYQuu4k17A4tQ\n" + + "rIPpOQYZges/qsDFkZh3wyK5jL5WEFVdOasf6wsfszExnPmcZS4axxoYJfiuilrN\n" + + "hnwinBAqfi3S0sw5MpSI4Zl1AbOrHG4zDI62Gti2PKiMGyYDZTS9xPrBLbN95Kby\n" + + "FFclQLEzA9RJcS1nHFsWtRgHjGPhhjCQxEm9NQ1nePFhCfBfApyfH1VM2VCOQum6\n" + + "Ci9bMuHWjTjckC84mzF99kOxOWVU7mwS6gnJqBzpuz8t3zq8/iQ2y7QrmZV+jTJP\n" + + "WQIDAQAB\n" + + "-----END PUBLIC KEY-----\n") + +var testPubKeyRSA *rsa.PublicKey + +func init() { + block, _ := pem.Decode(testPubKey) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + testPubKeyRSA = pub.(*rsa.PublicKey) +} + +func TestScrambleOldPass(t *testing.T) { + scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2} + vectors := []struct { + pass string + out string + }{ + {" pass", "47575c5a435b4251"}, + {"pass ", "47575c5a435b4251"}, + {"123\t456", "575c47505b5b5559"}, + {"C0mpl!ca ted#PASS123", "5d5d554849584a45"}, + } + for _, tuple := range vectors { + ours := scrambleOldPassword(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed old password %q", tuple.pass) + } + } +} + +func TestScrambleSHA256Pass(t *testing.T) { + scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21} + vectors := []struct { + pass string + out string + }{ + {"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"}, + {"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"}, + } + for _, tuple := range vectors { + ours := scrambleSHA256Password(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed SHA256 password %q", tuple.pass) + } + } +} + +func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56, + 139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15, + 184, 150, 26, 61, 57, 235} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 3, // Fast Auth Success + 7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // pub key response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthFastCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowNativePasswords = false + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthFastNativePassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252, + 172, 50, 211, 192, 240, 164, 26, 48, 207, 45} + if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{1} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + // hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + + // unset TLS config to prevent the actual establishment of a TLS wrapper + mc.cfg.tls = nil + + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 6, 1}, testPubKey...), + + // OK + {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 4 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Pub Key Request + 1, 0, 0, 5, 2, + + // 3. Packet: Encrypted Password + 0, 1, 0, 7, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {2, 0, 0, 4, 1, 4}, // Perform Full Authentication + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Cleartext password + 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + conn.maxReads = 1 + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthSwitchCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = false + + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + conn.maxReads = 1 + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthSwitchNativePassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103, + 21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +func TestOldAuthSwitchNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +func TestAuthSwitchOldPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +func TestOldAuthSwitch(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} +func TestAuthSwitchOldPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +func TestOldAuthSwitchPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // OldAuthSwitch request. + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Empty Password + 1, 0, 0, 3, 0, + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Pub Key Request + 1, 0, 0, 3, 1, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Encrypted Password + 0, 1, 0, 3, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Cleartext Password + 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/benchmark_test.go b/vendor/github.com/go-sql-driver/mysql/benchmark_test.go new file mode 100644 index 0000000..3e25a3b --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/benchmark_test.go @@ -0,0 +1,373 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "fmt" + "math" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +type TB testing.B + +func (tb *TB) check(err error) { + if err != nil { + tb.Fatal(err) + } +} + +func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB { + tb.check(err) + return db +} + +func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows { + tb.check(err) + return rows +} + +func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { + tb.check(err) + return stmt +} + +func initDB(b *testing.B, queries ...string) *sql.DB { + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + for _, query := range queries { + if _, err := db.Exec(query); err != nil { + b.Fatalf("error on %q: %v", query, err) + } + } + return db +} + +const concurrencyLevel = 10 + +func BenchmarkQuery(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + var got string + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Errorf("query = %q; want one", got) + wg.Done() + return + } + } + }() + } +} + +func BenchmarkExec(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := tb.checkDB(sql.Open("mysql", dsn)) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("DO 1")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + if _, err := stmt.Exec(); err != nil { + b.Fatal(err.Error()) + } + } + }() + } +} + +// data, but no db writes +var roundtripSample []byte + +func initRoundtripBenchmarks() ([]byte, int, int) { + if roundtripSample == nil { + roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024)) + } + return roundtripSample, 16, len(roundtripSample) +} + +func BenchmarkRoundtripTxt(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + sampleString := string(sample) + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + b.StartTimer() + var result string + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sampleString[0:length] + rows := tb.checkRows(db.Query(`SELECT "` + test + `"`)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if result != test { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkRoundtripBin(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + stmt := tb.checkStmt(db.Prepare("SELECT ?")) + defer stmt.Close() + b.StartTimer() + var result sql.RawBytes + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sample[0:length] + rows := tb.checkRows(stmt.Query(test)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if !bytes.Equal(result, test) { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkInterpolation(b *testing.B) { + mc := &mysqlConn{ + cfg: &Config{ + InterpolateParams: true, + Loc: time.UTC, + }, + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + buf: newBuffer(nil), + } + + args := []driver.Value{ + int64(42424242), + float64(math.Pi), + false, + time.Unix(1423411542, 807015000), + []byte("bytes containing special chars ' \" \a \x00"), + "string containing special chars ' \" \a \x00", + } + q := "SELECT ?, ?, ?, ?, ?, ?" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := mc.interpolateParams(q, args) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var got string + for pb.Next() { + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Fatalf("query = %q; want one", got) + } + } + }) +} + +func BenchmarkQueryContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := stmt.ExecContext(ctx); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkExecContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes. +// "size=" means size of each blobs. +func BenchmarkQueryRawBytes(b *testing.B) { + var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} + db := initDB(b, + "DROP TABLE IF EXISTS bench_rawbytes", + "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", + ) + defer db.Close() + + blob := make([]byte, sizes[len(sizes)-1]) + for i := range blob { + blob[i] = 42 + } + for i := 0; i < 100; i++ { + _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) + if err != nil { + b.Fatal(err) + } + } + + for _, s := range sizes { + b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + b.ReportAllocs() + b.ResetTimer() + + for j := 0; j < b.N; j++ { + rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) + if err != nil { + b.Fatal(err) + } + nrows := 0 + for rows.Next() { + var buf sql.RawBytes + err := rows.Scan(&buf) + if err != nil { + b.Fatal(err) + } + if len(buf) != s { + b.Fatalf("size mismatch: expected %v, got %v", s, len(buf)) + } + nrows++ + } + rows.Close() + if nrows != 100 { + b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows) + } + } + }) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/go-sql-driver/mysql/buffer.go index 509ce89..0774c5c 100644 --- a/vendor/github.com/go-sql-driver/mysql/buffer.go +++ b/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -8,53 +8,86 @@ package mysql -import "io" +import ( + "io" + "net" + "time" +) const defaultBufSize = 4096 +const maxCachedBufSize = 256 * 1024 // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. +// This buffer is backed by two byte slices in a double-buffering scheme type buffer struct { - buf []byte - rd io.Reader - idx int - length int + buf []byte // buf is a byte buffer who's length and capacity are equal. + nc net.Conn + idx int + length int + timeout time.Duration + dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer + flipcnt uint // flipccnt is the current buffer counter for double-buffering } -func newBuffer(rd io.Reader) buffer { - var b [defaultBufSize]byte +// newBuffer allocates and returns a new buffer. +func newBuffer(nc net.Conn) buffer { + fg := make([]byte, defaultBufSize) return buffer{ - buf: b[:], - rd: rd, + buf: fg, + nc: nc, + dbuf: [2][]byte{fg, nil}, } } +// flip replaces the active buffer with the background buffer +// this is a delayed flip that simply increases the buffer counter; +// the actual flip will be performed the next time we call `buffer.fill` +func (b *buffer) flip() { + b.flipcnt += 1 +} + // fill reads into the buffer until at least _need_ bytes are in it func (b *buffer) fill(need int) error { n := b.length + // fill data into its double-buffering target: if we've called + // flip on this buffer, we'll be copying to the background buffer, + // and then filling it with network data; otherwise we'll just move + // the contents of the current buffer to the front before filling it + dest := b.dbuf[b.flipcnt&1] - // move existing data to the beginning - if n > 0 && b.idx > 0 { - copy(b.buf[0:n], b.buf[b.idx:]) - } - - // grow buffer if necessary - // TODO: let the buffer shrink again at some point - // Maybe keep the org buf slice and swap back? - if need > len(b.buf) { + // grow buffer if necessary to fit the whole packet. + if need > len(dest) { // Round up to the next multiple of the default size - newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) - copy(newBuf, b.buf) - b.buf = newBuf + dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + + // if the allocated buffer is not too large, move it to backing storage + // to prevent extra allocations on applications that perform large reads + if len(dest) <= maxCachedBufSize { + b.dbuf[b.flipcnt&1] = dest + } } + // if we're filling the fg buffer, move the existing data to the start of it. + // if we're filling the bg buffer, copy over the data + if n > 0 { + copy(dest[:n], b.buf[b.idx:]) + } + + b.buf = dest b.idx = 0 for { - nn, err := b.rd.Read(b.buf[n:]) + if b.timeout > 0 { + if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + return err + } + } + + nn, err := b.nc.Read(b.buf[n:]) n += nn switch err { @@ -94,43 +127,56 @@ func (b *buffer) readNext(need int) ([]byte, error) { return b.buf[offset:b.idx], nil } -// returns a buffer with the requested size. +// takeBuffer returns a buffer with the requested size. // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] + if length <= cap(b.buf) { + return b.buf[:length], nil } if length < maxPacketSize { b.buf = make([]byte, length) - return b.buf + return b.buf, nil } - return make([]byte, length) + + // buffer is larger than we want to store. + return make([]byte, length), nil } -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { - return b.buf[:length] +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer } - return nil + return b.buf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { - if b.length == 0 { - return b.buf +func (b *buffer) takeCompleteBuffer() ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { + if b.length > 0 { + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] } return nil } diff --git a/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/go-sql-driver/mysql/collations.go index 6c1d613..8d2b556 100644 --- a/vendor/github.com/go-sql-driver/mysql/collations.go +++ b/vendor/github.com/go-sql-driver/mysql/collations.go @@ -8,182 +8,190 @@ package mysql -const defaultCollation byte = 33 // utf8_general_ci +const defaultCollation = "utf8mb4_general_ci" +const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: -// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID +// +// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. +// +// ucs2, utf16, and utf32 can't be used for connection charset. +// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset +// They are commented out to reduce this map. var collations = map[string]byte{ - "big5_chinese_ci": 1, - "latin2_czech_cs": 2, - "dec8_swedish_ci": 3, - "cp850_general_ci": 4, - "latin1_german1_ci": 5, - "hp8_english_ci": 6, - "koi8r_general_ci": 7, - "latin1_swedish_ci": 8, - "latin2_general_ci": 9, - "swe7_swedish_ci": 10, - "ascii_general_ci": 11, - "ujis_japanese_ci": 12, - "sjis_japanese_ci": 13, - "cp1251_bulgarian_ci": 14, - "latin1_danish_ci": 15, - "hebrew_general_ci": 16, - "tis620_thai_ci": 18, - "euckr_korean_ci": 19, - "latin7_estonian_cs": 20, - "latin2_hungarian_ci": 21, - "koi8u_general_ci": 22, - "cp1251_ukrainian_ci": 23, - "gb2312_chinese_ci": 24, - "greek_general_ci": 25, - "cp1250_general_ci": 26, - "latin2_croatian_ci": 27, - "gbk_chinese_ci": 28, - "cp1257_lithuanian_ci": 29, - "latin5_turkish_ci": 30, - "latin1_german2_ci": 31, - "armscii8_general_ci": 32, - "utf8_general_ci": 33, - "cp1250_czech_cs": 34, - "ucs2_general_ci": 35, - "cp866_general_ci": 36, - "keybcs2_general_ci": 37, - "macce_general_ci": 38, - "macroman_general_ci": 39, - "cp852_general_ci": 40, - "latin7_general_ci": 41, - "latin7_general_cs": 42, - "macce_bin": 43, - "cp1250_croatian_ci": 44, - "utf8mb4_general_ci": 45, - "utf8mb4_bin": 46, - "latin1_bin": 47, - "latin1_general_ci": 48, - "latin1_general_cs": 49, - "cp1251_bin": 50, - "cp1251_general_ci": 51, - "cp1251_general_cs": 52, - "macroman_bin": 53, - "utf16_general_ci": 54, - "utf16_bin": 55, - "utf16le_general_ci": 56, - "cp1256_general_ci": 57, - "cp1257_bin": 58, - "cp1257_general_ci": 59, - "utf32_general_ci": 60, - "utf32_bin": 61, - "utf16le_bin": 62, - "binary": 63, - "armscii8_bin": 64, - "ascii_bin": 65, - "cp1250_bin": 66, - "cp1256_bin": 67, - "cp866_bin": 68, - "dec8_bin": 69, - "greek_bin": 70, - "hebrew_bin": 71, - "hp8_bin": 72, - "keybcs2_bin": 73, - "koi8r_bin": 74, - "koi8u_bin": 75, - "latin2_bin": 77, - "latin5_bin": 78, - "latin7_bin": 79, - "cp850_bin": 80, - "cp852_bin": 81, - "swe7_bin": 82, - "utf8_bin": 83, - "big5_bin": 84, - "euckr_bin": 85, - "gb2312_bin": 86, - "gbk_bin": 87, - "sjis_bin": 88, - "tis620_bin": 89, - "ucs2_bin": 90, - "ujis_bin": 91, - "geostd8_general_ci": 92, - "geostd8_bin": 93, - "latin1_spanish_ci": 94, - "cp932_japanese_ci": 95, - "cp932_bin": 96, - "eucjpms_japanese_ci": 97, - "eucjpms_bin": 98, - "cp1250_polish_ci": 99, - "utf16_unicode_ci": 101, - "utf16_icelandic_ci": 102, - "utf16_latvian_ci": 103, - "utf16_romanian_ci": 104, - "utf16_slovenian_ci": 105, - "utf16_polish_ci": 106, - "utf16_estonian_ci": 107, - "utf16_spanish_ci": 108, - "utf16_swedish_ci": 109, - "utf16_turkish_ci": 110, - "utf16_czech_ci": 111, - "utf16_danish_ci": 112, - "utf16_lithuanian_ci": 113, - "utf16_slovak_ci": 114, - "utf16_spanish2_ci": 115, - "utf16_roman_ci": 116, - "utf16_persian_ci": 117, - "utf16_esperanto_ci": 118, - "utf16_hungarian_ci": 119, - "utf16_sinhala_ci": 120, - "utf16_german2_ci": 121, - "utf16_croatian_ci": 122, - "utf16_unicode_520_ci": 123, - "utf16_vietnamese_ci": 124, - "ucs2_unicode_ci": 128, - "ucs2_icelandic_ci": 129, - "ucs2_latvian_ci": 130, - "ucs2_romanian_ci": 131, - "ucs2_slovenian_ci": 132, - "ucs2_polish_ci": 133, - "ucs2_estonian_ci": 134, - "ucs2_spanish_ci": 135, - "ucs2_swedish_ci": 136, - "ucs2_turkish_ci": 137, - "ucs2_czech_ci": 138, - "ucs2_danish_ci": 139, - "ucs2_lithuanian_ci": 140, - "ucs2_slovak_ci": 141, - "ucs2_spanish2_ci": 142, - "ucs2_roman_ci": 143, - "ucs2_persian_ci": 144, - "ucs2_esperanto_ci": 145, - "ucs2_hungarian_ci": 146, - "ucs2_sinhala_ci": 147, - "ucs2_german2_ci": 148, - "ucs2_croatian_ci": 149, - "ucs2_unicode_520_ci": 150, - "ucs2_vietnamese_ci": 151, - "ucs2_general_mysql500_ci": 159, - "utf32_unicode_ci": 160, - "utf32_icelandic_ci": 161, - "utf32_latvian_ci": 162, - "utf32_romanian_ci": 163, - "utf32_slovenian_ci": 164, - "utf32_polish_ci": 165, - "utf32_estonian_ci": 166, - "utf32_spanish_ci": 167, - "utf32_swedish_ci": 168, - "utf32_turkish_ci": 169, - "utf32_czech_ci": 170, - "utf32_danish_ci": 171, - "utf32_lithuanian_ci": 172, - "utf32_slovak_ci": 173, - "utf32_spanish2_ci": 174, - "utf32_roman_ci": 175, - "utf32_persian_ci": 176, - "utf32_esperanto_ci": 177, - "utf32_hungarian_ci": 178, - "utf32_sinhala_ci": 179, - "utf32_german2_ci": 180, - "utf32_croatian_ci": 181, - "utf32_unicode_520_ci": 182, - "utf32_vietnamese_ci": 183, + "big5_chinese_ci": 1, + "latin2_czech_cs": 2, + "dec8_swedish_ci": 3, + "cp850_general_ci": 4, + "latin1_german1_ci": 5, + "hp8_english_ci": 6, + "koi8r_general_ci": 7, + "latin1_swedish_ci": 8, + "latin2_general_ci": 9, + "swe7_swedish_ci": 10, + "ascii_general_ci": 11, + "ujis_japanese_ci": 12, + "sjis_japanese_ci": 13, + "cp1251_bulgarian_ci": 14, + "latin1_danish_ci": 15, + "hebrew_general_ci": 16, + "tis620_thai_ci": 18, + "euckr_korean_ci": 19, + "latin7_estonian_cs": 20, + "latin2_hungarian_ci": 21, + "koi8u_general_ci": 22, + "cp1251_ukrainian_ci": 23, + "gb2312_chinese_ci": 24, + "greek_general_ci": 25, + "cp1250_general_ci": 26, + "latin2_croatian_ci": 27, + "gbk_chinese_ci": 28, + "cp1257_lithuanian_ci": 29, + "latin5_turkish_ci": 30, + "latin1_german2_ci": 31, + "armscii8_general_ci": 32, + "utf8_general_ci": 33, + "cp1250_czech_cs": 34, + //"ucs2_general_ci": 35, + "cp866_general_ci": 36, + "keybcs2_general_ci": 37, + "macce_general_ci": 38, + "macroman_general_ci": 39, + "cp852_general_ci": 40, + "latin7_general_ci": 41, + "latin7_general_cs": 42, + "macce_bin": 43, + "cp1250_croatian_ci": 44, + "utf8mb4_general_ci": 45, + "utf8mb4_bin": 46, + "latin1_bin": 47, + "latin1_general_ci": 48, + "latin1_general_cs": 49, + "cp1251_bin": 50, + "cp1251_general_ci": 51, + "cp1251_general_cs": 52, + "macroman_bin": 53, + //"utf16_general_ci": 54, + //"utf16_bin": 55, + //"utf16le_general_ci": 56, + "cp1256_general_ci": 57, + "cp1257_bin": 58, + "cp1257_general_ci": 59, + //"utf32_general_ci": 60, + //"utf32_bin": 61, + //"utf16le_bin": 62, + "binary": 63, + "armscii8_bin": 64, + "ascii_bin": 65, + "cp1250_bin": 66, + "cp1256_bin": 67, + "cp866_bin": 68, + "dec8_bin": 69, + "greek_bin": 70, + "hebrew_bin": 71, + "hp8_bin": 72, + "keybcs2_bin": 73, + "koi8r_bin": 74, + "koi8u_bin": 75, + "utf8_tolower_ci": 76, + "latin2_bin": 77, + "latin5_bin": 78, + "latin7_bin": 79, + "cp850_bin": 80, + "cp852_bin": 81, + "swe7_bin": 82, + "utf8_bin": 83, + "big5_bin": 84, + "euckr_bin": 85, + "gb2312_bin": 86, + "gbk_bin": 87, + "sjis_bin": 88, + "tis620_bin": 89, + //"ucs2_bin": 90, + "ujis_bin": 91, + "geostd8_general_ci": 92, + "geostd8_bin": 93, + "latin1_spanish_ci": 94, + "cp932_japanese_ci": 95, + "cp932_bin": 96, + "eucjpms_japanese_ci": 97, + "eucjpms_bin": 98, + "cp1250_polish_ci": 99, + //"utf16_unicode_ci": 101, + //"utf16_icelandic_ci": 102, + //"utf16_latvian_ci": 103, + //"utf16_romanian_ci": 104, + //"utf16_slovenian_ci": 105, + //"utf16_polish_ci": 106, + //"utf16_estonian_ci": 107, + //"utf16_spanish_ci": 108, + //"utf16_swedish_ci": 109, + //"utf16_turkish_ci": 110, + //"utf16_czech_ci": 111, + //"utf16_danish_ci": 112, + //"utf16_lithuanian_ci": 113, + //"utf16_slovak_ci": 114, + //"utf16_spanish2_ci": 115, + //"utf16_roman_ci": 116, + //"utf16_persian_ci": 117, + //"utf16_esperanto_ci": 118, + //"utf16_hungarian_ci": 119, + //"utf16_sinhala_ci": 120, + //"utf16_german2_ci": 121, + //"utf16_croatian_ci": 122, + //"utf16_unicode_520_ci": 123, + //"utf16_vietnamese_ci": 124, + //"ucs2_unicode_ci": 128, + //"ucs2_icelandic_ci": 129, + //"ucs2_latvian_ci": 130, + //"ucs2_romanian_ci": 131, + //"ucs2_slovenian_ci": 132, + //"ucs2_polish_ci": 133, + //"ucs2_estonian_ci": 134, + //"ucs2_spanish_ci": 135, + //"ucs2_swedish_ci": 136, + //"ucs2_turkish_ci": 137, + //"ucs2_czech_ci": 138, + //"ucs2_danish_ci": 139, + //"ucs2_lithuanian_ci": 140, + //"ucs2_slovak_ci": 141, + //"ucs2_spanish2_ci": 142, + //"ucs2_roman_ci": 143, + //"ucs2_persian_ci": 144, + //"ucs2_esperanto_ci": 145, + //"ucs2_hungarian_ci": 146, + //"ucs2_sinhala_ci": 147, + //"ucs2_german2_ci": 148, + //"ucs2_croatian_ci": 149, + //"ucs2_unicode_520_ci": 150, + //"ucs2_vietnamese_ci": 151, + //"ucs2_general_mysql500_ci": 159, + //"utf32_unicode_ci": 160, + //"utf32_icelandic_ci": 161, + //"utf32_latvian_ci": 162, + //"utf32_romanian_ci": 163, + //"utf32_slovenian_ci": 164, + //"utf32_polish_ci": 165, + //"utf32_estonian_ci": 166, + //"utf32_spanish_ci": 167, + //"utf32_swedish_ci": 168, + //"utf32_turkish_ci": 169, + //"utf32_czech_ci": 170, + //"utf32_danish_ci": 171, + //"utf32_lithuanian_ci": 172, + //"utf32_slovak_ci": 173, + //"utf32_spanish2_ci": 174, + //"utf32_roman_ci": 175, + //"utf32_persian_ci": 176, + //"utf32_esperanto_ci": 177, + //"utf32_hungarian_ci": 178, + //"utf32_sinhala_ci": 179, + //"utf32_german2_ci": 180, + //"utf32_croatian_ci": 181, + //"utf32_unicode_520_ci": 182, + //"utf32_vietnamese_ci": 183, "utf8_unicode_ci": 192, "utf8_icelandic_ci": 193, "utf8_latvian_ci": 194, @@ -233,18 +241,25 @@ var collations = map[string]byte{ "utf8mb4_croatian_ci": 245, "utf8mb4_unicode_520_ci": 246, "utf8mb4_vietnamese_ci": 247, + "gb18030_chinese_ci": 248, + "gb18030_bin": 249, + "gb18030_unicode_520_ci": 250, + "utf8mb4_0900_ai_ci": 255, } // A blacklist of collations which is unsafe to interpolate parameters. // These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. -var unsafeCollations = map[byte]bool{ - 1: true, // big5_chinese_ci - 13: true, // sjis_japanese_ci - 28: true, // gbk_chinese_ci - 84: true, // big5_bin - 86: true, // gb2312_bin - 87: true, // gbk_bin - 88: true, // sjis_bin - 95: true, // cp932_japanese_ci - 96: true, // cp932_bin +var unsafeCollations = map[string]bool{ + "big5_chinese_ci": true, + "sjis_japanese_ci": true, + "gbk_chinese_ci": true, + "big5_bin": true, + "gb2312_bin": true, + "gbk_bin": true, + "sjis_bin": true, + "cp932_japanese_ci": true, + "cp932_bin": true, + "gb18030_chinese_ci": true, + "gb18030_bin": true, + "gb18030_unicode_520_ci": true, } diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck.go b/vendor/github.com/go-sql-driver/mysql/conncheck.go new file mode 100644 index 0000000..024eb28 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/conncheck.go @@ -0,0 +1,54 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package mysql + +import ( + "errors" + "io" + "net" + "syscall" +) + +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(conn net.Conn) error { + var sysErr error + + sysConn, ok := conn.(syscall.Conn) + if !ok { + return nil + } + rawConn, err := sysConn.SyscallConn() + if err != nil { + return err + } + + err = rawConn.Read(func(fd uintptr) bool { + var buf [1]byte + n, err := syscall.Read(int(fd), buf[:]) + switch { + case n == 0 && err == nil: + sysErr = io.EOF + case n > 0: + sysErr = errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + sysErr = nil + default: + sysErr = err + } + return true + }) + if err != nil { + return err + } + + return sysErr +} diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go b/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go new file mode 100644 index 0000000..ea7fb60 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go @@ -0,0 +1,17 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos + +package mysql + +import "net" + +func connCheck(conn net.Conn) error { + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck_test.go b/vendor/github.com/go-sql-driver/mysql/conncheck_test.go new file mode 100644 index 0000000..5399551 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/conncheck_test.go @@ -0,0 +1,38 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package mysql + +import ( + "testing" + "time" +) + +func TestStaleConnectionChecks(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("SET @@SESSION.wait_timeout = 2") + + if err := dbt.db.Ping(); err != nil { + dbt.Fatal(err) + } + + // wait for MySQL to close our connection + time.Sleep(3 * time.Second) + + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + dbt.Fatal(err) + } + }) +} diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index 72ed09d..b07cd76 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -9,9 +9,11 @@ package mysql import ( - "crypto/tls" + "context" + "database/sql" "database/sql/driver" - "errors" + "encoding/json" + "io" "net" "strconv" "strings" @@ -21,40 +23,31 @@ import ( type mysqlConn struct { buf buffer netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. affectedRows uint64 insertId uint64 - cfg *config - maxPacketAllowed int + cfg *Config + maxAllowedPacket int maxWriteSize int + writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 parseTime bool - strict bool -} + reset bool // set when the Go SQL package calls ResetSession -type config struct { - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - tls *tls.Config - timeout time.Duration - collation uint8 - allowAllFiles bool - allowOldPasswords bool - allowCleartextPasswords bool - clientFoundRows bool - columnsWithAlias bool - interpolateParams bool + // for context support (Go 1.8+) + watching bool + watcher chan<- context.Context + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed } // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { - for param, val := range mc.cfg.params { + for param, val := range mc.cfg.Params { switch param { // Charset case "charset": @@ -70,27 +63,6 @@ func (mc *mysqlConn) handleParams() (err error) { return } - // time.Time parsing - case "parseTime": - var isBool bool - mc.parseTime, isBool = readBool(val) - if !isBool { - return errors.New("Invalid Bool value: " + val) - } - - // Strict mode - case "strict": - var isBool bool - mc.strict, isBool = readBool(val) - if !isBool { - return errors.New("Invalid Bool value: " + val) - } - - // Compression - case "compress": - err = errors.New("Compression not implemented yet") - return - // System Vars default: err = mc.exec("SET " + param + "=" + val + "") @@ -103,46 +75,89 @@ func (mc *mysqlConn) handleParams() (err error) { return } +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.netConn == nil { + return mc.begin(false) +} + +func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + var q string + if readOnly { + q = "START TRANSACTION READ ONLY" + } else { + q = "START TRANSACTION" + } + err := mc.exec(q) if err == nil { return &mysqlTx{mc}, err } - - return nil, err + return nil, mc.markBadConn(err) } func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if mc.netConn != nil { + if !mc.closed.IsSet() { err = mc.writeCommandPacket(comQuit) - if err == nil { - err = mc.netConn.Close() - } else { - mc.netConn.Close() - } - mc.netConn = nil } - mc.cfg = nil - mc.buf.rd = nil + mc.cleanup() return } -func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { +// Closes the network connection and unsets internal variables. Do not call this +// function after successfully authentication, call Close instead. This function +// is called before auth or on auth failure because MySQL will have already +// closed the network connection. +func (mc *mysqlConn) cleanup() { + if !mc.closed.TrySet(true) { + return + } + + // Makes cleanup idempotent + close(mc.closech) if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) error() error { + if mc.closed.IsSet() { + if err := mc.canceled.Value(); err != nil { + return err + } + return ErrInvalidConn + } + return nil +} + +func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { - return nil, err + // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. + errLog.Print(err) + return nil, driver.ErrBadConn } stmt := &mysqlStmt{ @@ -167,11 +182,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { - buf := mc.buf.takeCompleteBuffer() - if buf == nil { + // Number of ? should be same to len(args) + if strings.Count(query, "?") != len(args) { + return "", driver.ErrSkip + } + + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return "", driver.ErrBadConn + errLog.Print(err) + return "", ErrInvalidConn } buf = buf[:0] argPos := 0 @@ -196,6 +216,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin switch v := arg.(type) { case int64: buf = strconv.AppendInt(buf, v, 10) + case uint64: + // Handle uint64 explicitly because our custom ConvertValue emits unsigned values + buf = strconv.AppendUint(buf, v, 10) case float64: buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: @@ -208,7 +231,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { - v := v.In(mc.cfg.loc) + v := v.In(mc.cfg.Loc) v = v.Add(time.Nanosecond * 500) // To round under microsecond year := v.Year() year100 := year / 100 @@ -249,6 +272,14 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } buf = append(buf, '\'') } + case json.RawMessage: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') case []byte: if v == nil { buf = append(buf, "NULL"...) @@ -273,7 +304,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - if len(buf)+4 > mc.maxPacketAllowed { + if len(buf)+4 > mc.maxAllowedPacket { return "", driver.ErrSkip } } @@ -284,12 +315,12 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { - if !mc.cfg.interpolateParams { + if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement @@ -298,7 +329,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, err } query = prepared - args = nil } mc.affectedRows = 0 mc.insertId = 0 @@ -310,37 +340,48 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err insertId: int64(mc.insertId), }, err } - return nil, err + return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command - err := mc.writeCommandPacketStr(comQuery, query) - if err != nil { - return err + if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + return mc.markBadConn(err) } // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil && resLen > 0 { - if err = mc.readUntilEOF(); err != nil { + if err != nil { + return err + } + + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { return err } - err = mc.readUntilEOF() + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } } - return err + return mc.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.netConn == nil { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { - if !mc.cfg.interpolateParams { + if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try client-side prepare to reduce roundtrip @@ -349,7 +390,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, err } query = prepared - args = nil } // Send command err := mc.writeCommandPacketStr(comQuery, query) @@ -362,15 +402,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro rows.mc = mc if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } } + // Columns - rows.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } - return nil, err + return nil, mc.markBadConn(err) } // Gets the value of the given MySQL System Variable @@ -386,6 +433,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if err == nil { rows := new(textRows) rows.mc = mc + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns @@ -401,3 +449,212 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } return nil, err } + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.canceled.Set(err) + mc.cleanup() +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return mc.markBadConn(err) + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. + if ctx.Done() == nil { + return nil + } + // When watcher is not alive, can't watch it. + if mc.watcher == nil { + return nil + } + + mc.watching = true + mc.watcher <- ctx + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan context.Context, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx context.Context + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + mc.reset = true + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/connection_test.go b/vendor/github.com/go-sql-driver/mysql/connection_test.go new file mode 100644 index 0000000..a6d6773 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/connection_test.go @@ -0,0 +1,203 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "encoding/json" + "errors" + "net" + "testing" +) + +func TestInterpolateParams(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT 42+'gopher'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + +func TestInterpolateParamsJSONRawMessage(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + buf, err := json.Marshal(struct { + Value int `json:"value"` + }{Value: 42}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + q, err := mc.interpolateParams("SELECT ?", []driver.Value{json.RawMessage(buf)}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT '{\"value\":42}'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + +func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +// We don't support placeholder in string literal for now. +// https://github.com/go-sql-driver/mysql/pull/490 +func TestInterpolateParamsPlaceholderInString(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) + // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +func TestInterpolateParamsUint64(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)}) + if err != nil { + t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q) + } + if q != "SELECT 42" { + t.Errorf("Expected uint64 interpolation to work, got q=%#v", q) + } +} + +func TestCheckNamedValue(t *testing.T) { + value := driver.NamedValue{Value: ^uint64(0)} + x := &mysqlConn{} + err := x.CheckNamedValue(&value) + + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if value.Value != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value) + } +} + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +} + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/connector.go b/vendor/github.com/go-sql-driver/mysql/connector.go new file mode 100644 index 0000000..d567b4e --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/connector.go @@ -0,0 +1,146 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "net" +) + +type connector struct { + cfg *Config // immutable private copy. +} + +// Connect implements driver.Connector interface. +// Connect returns a connection to the database. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: c.cfg, + } + mc.parseTime = mc.cfg.ParseTime + + // Connect to Server + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + mc.netConn, err = dial(dctx, mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + } + + if err != nil { + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + // Call startWatcher for context support (From Go 1.8) + mc.startWatcher() + if err := mc.watchCancel(ctx); err != nil { + mc.cleanup() + return nil, err + } + defer mc.finish() + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + authData, plugin, err := mc.readHandshakePacket() + if err != nil { + mc.cleanup() + return nil, err + } + + if plugin == "" { + plugin = defaultAuthPlugin + } + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = mc.handleAuthResult(authData, plugin); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +// Driver implements driver.Connector interface. +// Driver returns &MySQLDriver{}. +func (c *connector) Driver() driver.Driver { + return &MySQLDriver{} +} diff --git a/vendor/github.com/go-sql-driver/mysql/connector_test.go b/vendor/github.com/go-sql-driver/mysql/connector_test.go new file mode 100644 index 0000000..976903c --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/connector_test.go @@ -0,0 +1,30 @@ +package mysql + +import ( + "context" + "net" + "testing" + "time" +) + +func TestConnectorReturnsTimeout(t *testing.T) { + connector := &connector{&Config{ + Net: "tcp", + Addr: "1.1.1.1:1234", + Timeout: 10 * time.Millisecond, + }} + + _, err := connector.Connect(context.Background()) + if err == nil { + t.Fatal("error expected") + } + + if nerr, ok := err.(*net.OpError); ok { + expected := "dial tcp 1.1.1.1:1234: i/o timeout" + if nerr.Error() != expected { + t.Fatalf("expected %q, got %q", expected, nerr.Error()) + } + } else { + t.Fatalf("expected %T, got %T", nerr, err) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/const.go b/vendor/github.com/go-sql-driver/mysql/const.go index dddc129..b1e6b85 100644 --- a/vendor/github.com/go-sql-driver/mysql/const.go +++ b/vendor/github.com/go-sql-driver/mysql/const.go @@ -9,7 +9,9 @@ package mysql const ( - minProtocolVersion byte = 10 + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" ) @@ -18,10 +20,11 @@ const ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html const ( - iOK byte = 0x00 - iLocalInFile byte = 0xfb - iEOF byte = 0xfe - iERR byte = 0xff + iOK byte = 0x00 + iAuthMoreData byte = 0x01 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags @@ -87,8 +90,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + const ( - fieldTypeDecimal byte = iota + fieldTypeDecimal fieldType = iota fieldTypeTiny fieldTypeShort fieldTypeLong @@ -107,7 +112,8 @@ const ( fieldTypeBit ) const ( - fieldTypeNewDecimal byte = iota + 0xf6 + fieldTypeJSON fieldType = iota + 0xf5 + fieldTypeNewDecimal fieldTypeEnum fieldTypeSet fieldTypeTinyBLOB @@ -160,3 +166,9 @@ const ( statusInTransReadonly statusSessionStateChanged ) + +const ( + cachingSha2PasswordRequestPublicKey = 2 + cachingSha2PasswordFastAuthSuccess = 3 + cachingSha2PasswordPerformFullAuthentication = 4 +) diff --git a/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/go-sql-driver/mysql/driver.go index d310624..c1bdf11 100644 --- a/vendor/github.com/go-sql-driver/mysql/driver.go +++ b/vendor/github.com/go-sql-driver/mysql/driver.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// Package mysql provides a MySQL driver for Go's database/sql package. // // The driver should be used via the database/sql package: // @@ -17,133 +17,91 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" "net" + "sync" ) -// This struct is exported to make the driver directly accessible. +// MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct{} // DialFunc is a function which can be used to establish the network connection. // Custom dial functions must be registered with RegisterDial +// +// Deprecated: users should register a DialContextFunc instead type DialFunc func(addr string) (net.Conn, error) -var dials map[string]DialFunc +// DialContextFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDialContext +type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error) -// RegisterDial registers a custom dial function. It can then be used by the +var ( + dialsLock sync.RWMutex + dials map[string]DialContextFunc +) + +// RegisterDialContext registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. -// addr is passed as a parameter to the dial function. -func RegisterDial(net string, dial DialFunc) { +// The current context for the connection and its address is passed to the dial function. +func RegisterDialContext(net string, dial DialContextFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() if dials == nil { - dials = make(map[string]DialFunc) + dials = make(map[string]DialContextFunc) } dials[net] = dial } +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +// +// Deprecated: users should call RegisterDialContext instead +func RegisterDial(network string, dial DialFunc) { + RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) { + return dial(addr) + }) +} + // Open new Connection. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated +// the DSN string is formatted func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { - var err error - - // New mysqlConn - mc := &mysqlConn{ - maxPacketAllowed: maxPacketSize, - maxWriteSize: maxPacketSize - 1, - } - mc.cfg, err = parseDSN(dsn) + cfg, err := ParseDSN(dsn) if err != nil { return nil, err } - - // Connect to Server - if dial, ok := dials[mc.cfg.net]; ok { - mc.netConn, err = dial(mc.cfg.addr) - } else { - nd := net.Dialer{Timeout: mc.cfg.timeout} - mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr) + c := &connector{ + cfg: cfg, } - if err != nil { - return nil, err - } - - // Enable TCP Keepalives on TCP connections - if tc, ok := mc.netConn.(*net.TCPConn); ok { - if err := tc.SetKeepAlive(true); err != nil { - // Don't send COM_QUIT before handshake. - mc.netConn.Close() - mc.netConn = nil - return nil, err - } - } - - mc.buf = newBuffer(mc.netConn) - - // Reading Handshake Initialization Packet - cipher, err := mc.readInitPacket() - if err != nil { - mc.Close() - return nil, err - } - - // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { - mc.Close() - return nil, err - } - - // Read Result Packet - err = mc.readResultOK() - if err != nil { - // Retry with old authentication method, if allowed - if mc.cfg != nil && mc.cfg.allowOldPasswords && err == ErrOldPassword { - if err = mc.writeOldAuthPacket(cipher); err != nil { - mc.Close() - return nil, err - } - if err = mc.readResultOK(); err != nil { - mc.Close() - return nil, err - } - } else if mc.cfg != nil && mc.cfg.allowCleartextPasswords && err == ErrCleartextPassword { - if err = mc.writeClearAuthPacket(); err != nil { - mc.Close() - return nil, err - } - if err = mc.readResultOK(); err != nil { - mc.Close() - return nil, err - } - } else { - mc.Close() - return nil, err - } - - } - - // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") - if err != nil { - mc.Close() - return nil, err - } - mc.maxPacketAllowed = stringToInt(maxap) - 1 - if mc.maxPacketAllowed < maxPacketSize { - mc.maxWriteSize = mc.maxPacketAllowed - } - - // Handle DSN Params - err = mc.handleParams() - if err != nil { - mc.Close() - return nil, err - } - - return mc, nil + return c.Connect(context.Background()) } func init() { sql.Register("mysql", &MySQLDriver{}) } + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + cfg = cfg.Clone() + // normalize the contents of cfg so calls to NewConnector have the same + // behavior as MySQLDriver.OpenConnector + if err := cfg.normalize(); err != nil { + return nil, err + } + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/driver_test.go b/vendor/github.com/go-sql-driver/mysql/driver_test.go new file mode 100644 index 0000000..ace083d --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/driver_test.go @@ -0,0 +1,3165 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "context" + "crypto/tls" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net" + "net/url" + "os" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Ensure that all the driver interfaces are implemented +var ( + _ driver.Rows = &binaryRows{} + _ driver.Rows = &textRows{} +) + +var ( + user string + pass string + prot string + addr string + dbname string + dsn string + netAddr string + available bool +) + +var ( + tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) + sDate = "2012-06-14" + tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC) + sDateTime = "2011-11-20 21:27:37" + tDate0 = time.Time{} + sDate0 = "0000-00-00" + sDateTime0 = "0000-00-00 00:00:00" +) + +// See https://github.com/go-sql-driver/mysql/wiki/Testing +func init() { + // get environment variables + env := func(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue + } + user = env("MYSQL_TEST_USER", "root") + pass = env("MYSQL_TEST_PASS", "") + prot = env("MYSQL_TEST_PROT", "tcp") + addr = env("MYSQL_TEST_ADDR", "localhost:3306") + dbname = env("MYSQL_TEST_DBNAME", "gotest") + netAddr = fmt.Sprintf("%s(%s)", prot, addr) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname) + c, err := net.Dial(prot, addr) + if err == nil { + available = true + c.Close() + } +} + +type DBTest struct { + *testing.T + db *sql.DB +} + +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + +func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn += "&multiStatements=true" + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + } +} + +func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + db.Exec("DROP TABLE IF EXISTS test") + + dsn2 := dsn + "&interpolateParams=true" + var db2 *sql.DB + if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + db2, err = sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db2.Close() + } + + dsn3 := dsn + "&multiStatements=true" + var db3 *sql.DB + if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + db3, err = sql.Open("mysql", dsn3) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db3.Close() + } + + dbt := &DBTest{t, db} + dbt2 := &DBTest{t, db2} + dbt3 := &DBTest{t, db3} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + if db2 != nil { + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + } + if db3 != nil { + test(dbt3) + dbt3.db.Exec("DROP TABLE IF EXISTS test") + } + } +} + +func (dbt *DBTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" + } + dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) +} + +func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + res, err := dbt.db.Exec(query, args...) + if err != nil { + dbt.fail("exec", query, err) + } + return res +} + +func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + rows, err := dbt.db.Query(query, args...) + if err != nil { + dbt.fail("query", query, err) + } + return rows +} + +func maybeSkip(t *testing.T, err error, skipErrno uint16) { + mySQLErr, ok := err.(*MySQLError) + if !ok { + return + } + + if mySQLErr.Number == skipErrno { + t.Skipf("skipping test for error: %v", err) + } +} + +func TestEmptyQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // just a comment, no query + rows := dbt.mustQuery("--") + defer rows.Close() + // will hang before #255 + if rows.Next() { + dbt.Errorf("next on rows must be false") + } + }) +} + +func TestCRUD(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + + // Test for unexpected data + var out bool + rows := dbt.mustQuery("SELECT * FROM test") + if rows.Next() { + dbt.Error("unexpected data in empty table") + } + rows.Close() + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) + } + if id != 0 { + dbt.Fatalf("expected InsertId 0, got %d", id) + } + + // Read + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if true != out { + dbt.Errorf("true != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Update + res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check Update + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if false != out { + dbt.Errorf("false != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Delete + res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check for unexpected rows + res = dbt.mustExec("DELETE FROM test") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 0 { + dbt.Fatalf("expected 0 affected row, got %d", count) + } + }) +} + +func TestMultiQuery(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Update + res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Read + var out int + rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") + if rows.Next() { + rows.Scan(&out) + if 5 != out { + dbt.Errorf("5 != %d", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + }) +} + +func TestInt(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} + in := int64(42) + var out int64 + var rows *sql.Rows + + // SIGNED + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // UNSIGNED ZEROFILL + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s ZEROFILL: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat32(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + in := float32(42.23) + var out float32 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64Placeholder(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") + rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestString(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} + in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" + var out string + var rows *sql.Rows + + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // BLOB + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + + id := 2 + in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + + "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." + dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) + + err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + if err != nil { + dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) + } else if out != in { + dbt.Errorf("BLOB: %s != %s", in, out) + } + }) +} + +func TestRawBytes(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := []byte("aaa") + v2 := []byte("bbb") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 sql.RawBytes + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + // https://github.com/go-sql-driver/mysql/issues/765 + // Appending to RawBytes shouldn't overwrite next RawBytes. + o1 = append(o1, "xyzzy"...) + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + +type testValuer struct { + value string +} + +func (tv testValuer) Value() (driver.Value, error) { + return tv.value, nil +} + +func TestValuer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuer{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + +type testValuerWithValidation struct { + value string +} + +func (tv testValuerWithValidation) Value() (driver.Value, error) { + if len(tv.value) == 0 { + return nil, fmt.Errorf("Invalid string valuer. Value must not be empty") + } + + return tv.value, nil +} + +func TestValuerWithValidation(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuerWithValidation{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM testValuer") + defer rows.Close() + + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + dbt.Errorf("Failed to check valuer error") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + dbt.Errorf("Failed to check nil") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + dbt.Errorf("Failed to check not valuer") + } + + dbt.mustExec("DROP TABLE IF EXISTS testValuer") + }) +} + +type timeTests struct { + dbtype string + tlayout string + tests []timeTest +} + +type timeTest struct { + s string // leading "!": do not use t as value in queries + t time.Time +} + +type timeMode byte + +func (t timeMode) String() string { + switch t { + case binaryString: + return "binary:string" + case binaryTime: + return "binary:time.Time" + case textString: + return "text:string" + } + panic("unsupported timeMode") +} + +func (t timeMode) Binary() bool { + switch t { + case binaryString, binaryTime: + return true + } + return false +} + +const ( + binaryString timeMode = iota + binaryTime + textString +) + +func (t timeTest) genQuery(dbtype string, mode timeMode) string { + var inner string + if mode.Binary() { + inner = "?" + } else { + inner = `"%s"` + } + return `SELECT cast(` + inner + ` as ` + dbtype + `)` +} + +func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { + var rows *sql.Rows + query := t.genQuery(dbtype, mode) + switch mode { + case binaryString: + rows = dbt.mustQuery(query, t.s) + case binaryTime: + rows = dbt.mustQuery(query, t.t) + case textString: + query = fmt.Sprintf(query, t.s) + rows = dbt.mustQuery(query) + default: + panic("unsupported mode") + } + defer rows.Close() + var err error + if !rows.Next() { + err = rows.Err() + if err == nil { + err = fmt.Errorf("no data") + } + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + var dst interface{} + err = rows.Scan(&dst) + if err != nil { + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + switch val := dst.(type) { + case []uint8: + str := string(val) + if str == t.s { + return + } + if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s { + // a fix mainly for TravisCI: + // accept full microsecond resolution in result for DATETIME columns + // where the binary protocol was used + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, str, + ) + case time.Time: + if val == t.t { + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, val.Format(tlayout), + ) + default: + fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t}) + dbt.Errorf("%s [%s]: unhandled type %T (is '%v')", + dbtype, mode, + val, val, + ) + } +} + +func TestDateTime(t *testing.T) { + afterTime := func(t time.Time, d string) time.Time { + dur, err := time.ParseDuration(d) + if err != nil { + panic(err) + } + return t.Add(dur) + } + // NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests + format := "2006-01-02 15:04:05.999999" + t0 := time.Time{} + tstr0 := "0000-00-00 00:00:00.000000" + testcases := []timeTests{ + {"DATE", format[:10], []timeTest{ + {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, + {t: t0, s: tstr0[:10]}, + }}, + {"DATETIME", format[:19], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(0)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(1)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, + {t: t0, s: tstr0[:21]}, + }}, + {"DATETIME(6)", format, []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, + {t: t0, s: tstr0}, + }}, + {"TIME", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(0)", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(1)", format[11:21], []timeTest{ + {t: afterTime(t0, "12345600ms")}, + {s: "!-12:34:56.7"}, + {s: "!-838:59:58.9"}, + {s: "!838:59:58.9"}, + {t: t0, s: tstr0[11:21]}, + }}, + {"TIME(6)", format[11:], []timeTest{ + {t: afterTime(t0, "1234567890123000ns")}, + {s: "!-12:34:56.789012"}, + {s: "!-838:59:58.999999"}, + {s: "!838:59:58.999999"}, + {t: t0, s: tstr0[11:]}, + }}, + } + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + microsecsSupported := false + zeroDateSupported := false + var rows *sql.Rows + var err error + rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) + if err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) + if err == nil { + rows.Scan(&zeroDateSupported) + rows.Close() + } + for _, setups := range testcases { + if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" { + // skip fractional second tests if unsupported by server + continue + } + for _, setup := range setups.tests { + allowBinTime := true + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } else if setup.s[0] == '!' { + // skip tests using setup.t as source in queries + allowBinTime = false + // fix setup.s - remove the "!" + setup.s = setup.s[1:] + } + if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] { + // skip disallowed 0000-00-00 date + continue + } + setup.run(dbt, setups.dbtype, setups.tlayout, textString) + setup.run(dbt, setups.dbtype, setups.tlayout, binaryString) + if allowBinTime { + setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime) + } + } + } + }) + } +} + +func TestTimestampMicros(t *testing.T) { + format := "2006-01-02 15:04:05.999999" + f0 := format[:19] + f1 := format[:21] + f6 := format[:26] + runTests(t, dsn, func(dbt *DBTest) { + // check if microseconds are supported. + // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width + // and not precision. + // Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html + microsecsSupported := false + if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + if !microsecsSupported { + // skip test + return + } + _, err := dbt.db.Exec(` + CREATE TABLE test ( + value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', + value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', + value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' + )`, + ) + if err != nil { + dbt.Error(err) + } + defer dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) + var res0, res1, res6 string + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Errorf("test contained no selectable values") + } + err = rows.Scan(&res0, &res1, &res6) + if err != nil { + dbt.Error(err) + } + if res0 != f0 { + dbt.Errorf("expected %q, got %q", f0, res0) + } + if res1 != f1 { + dbt.Errorf("expected %q, got %q", f1, res1) + } + if res6 != f6 { + dbt.Errorf("expected %q, got %q", f6, res6) + } + }) +} + +func TestNULL(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + nullStmt, err := dbt.db.Prepare("SELECT NULL") + if err != nil { + dbt.Fatal(err) + } + defer nullStmt.Close() + + nonNullStmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + defer nonNullStmt.Close() + + // NullBool + var nb sql.NullBool + // Invalid + if err = nullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if nb.Valid { + dbt.Error("valid NullBool which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if !nb.Valid { + dbt.Error("invalid NullBool which should be valid") + } else if nb.Bool != true { + dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) + } + + // NullFloat64 + var nf sql.NullFloat64 + // Invalid + if err = nullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if nf.Valid { + dbt.Error("valid NullFloat64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if !nf.Valid { + dbt.Error("invalid NullFloat64 which should be valid") + } else if nf.Float64 != float64(1) { + dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) + } + + // NullInt64 + var ni sql.NullInt64 + // Invalid + if err = nullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if ni.Valid { + dbt.Error("valid NullInt64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if !ni.Valid { + dbt.Error("invalid NullInt64 which should be valid") + } else if ni.Int64 != int64(1) { + dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) + } + + // NullString + var ns sql.NullString + // Invalid + if err = nullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if ns.Valid { + dbt.Error("valid NullString which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if !ns.Valid { + dbt.Error("invalid NullString which should be valid") + } else if ns.String != `1` { + dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") + } + + // nil-bytes + var b []byte + // Read nil + if err = nullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil []byte which should be nil") + } + // Read non-nil + if err = nonNullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil []byte which should be non-nil") + } + // Insert nil + b = nil + success := false + if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + dbt.Fatal(err) + } + if !success { + dbt.Error("inserting []byte(nil) as NULL failed") + } + // Check input==output with input==nil + b = nil + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil echo from nil input") + } + // Check input==output with input!=nil + b = []byte("") + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil echo from non-nil input") + } + + // Insert NULL + dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + + dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) + + var out interface{} + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if out != nil { + dbt.Errorf("%v != nil", out) + } + } else { + dbt.Error("no data") + } + }) +} + +func TestUint64(t *testing.T) { + const ( + u0 = uint64(0) + uall = ^u0 + uhigh = uall >> 1 + utop = ^uhigh + s0 = int64(0) + sall = ^s0 + shigh = int64(uhigh) + stop = ^shigh + ) + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + row := stmt.QueryRow( + u0, uhigh, utop, uall, + s0, shigh, stop, sall, + ) + + var ua, ub, uc, ud uint64 + var sa, sb, sc, sd int64 + + err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd) + if err != nil { + dbt.Fatal(err) + } + switch { + case ua != u0, + ub != uhigh, + uc != utop, + ud != uall, + sa != s0, + sb != shigh, + sc != stop, + sd != sall: + dbt.Fatal("unexpected result value") + } + }) +} + +func TestLongData(t *testing.T) { + runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { + var maxAllowedPacketSize int + err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) + if err != nil { + dbt.Fatal(err) + } + maxAllowedPacketSize-- + + // don't get too ambitious + if maxAllowedPacketSize > 1<<25 { + maxAllowedPacketSize = 1 << 25 + } + + dbt.mustExec("CREATE TABLE test (value LONGBLOB)") + + in := strings.Repeat(`a`, maxAllowedPacketSize+1) + var out string + var rows *sql.Rows + + // Long text data + const nonDataQueryLen = 28 // length query w/o value + inS := in[:maxAllowedPacketSize-nonDataQueryLen] + dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if inS != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + dbt.Fatalf("LONGBLOB: no data") + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Long binary data + dbt.mustExec("INSERT INTO test VALUES(?)", in) + rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + if err = rows.Err(); err != nil { + dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error()) + } else { + dbt.Fatal("LONGBLOB: no data (err: )") + } + } + }) +} + +func TestLoadData(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + verifyLoadDataResult := func() { + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + dbt.Fatal(err.Error()) + } + + i := 0 + values := [4]string{ + "a string", + "a string containing a \t", + "a string containing a \n", + "a string containing both \t\n", + } + + var id int + var value string + + for rows.Next() { + i++ + err = rows.Scan(&id, &value) + if err != nil { + dbt.Fatal(err.Error()) + } + if i != id { + dbt.Fatalf("%d != %d", i, id) + } + if values[i-1] != value { + dbt.Fatalf("%q != %q", values[i-1], value) + } + } + err = rows.Err() + if err != nil { + dbt.Fatal(err.Error()) + } + + if i != 4 { + dbt.Fatalf("rows count mismatch. Got %d, want 4", i) + } + } + + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") + + // Local File + file, err := ioutil.TempFile("", "gotest") + defer os.Remove(file.Name()) + if err != nil { + dbt.Fatal(err) + } + RegisterLocalFile(file.Name()) + + // Try first with empty file + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + var count int + err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&count) + if err != nil { + dbt.Fatal(err.Error()) + } + if count != 0 { + dbt.Fatalf("unexpected row count: got %d, want 0", count) + } + + // Then fille File with data and try to load it + file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") + file.Close() + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + verifyLoadDataResult() + + // Try with non-existing file + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent file didn't fail") + } else if err.Error() != "local file 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Reader + RegisterReaderHandler("test", func() io.Reader { + file, err = os.Open(file.Name()) + if err != nil { + dbt.Fatal(err) + } + return file + }) + dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent Reader didn't fail") + } else if err.Error() != "Reader 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + }) +} + +func TestFoundRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + }) + runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 matched rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 3 { + dbt.Fatalf("Expected 3 matched rows, got %d", count) + } + }) +} + +func TestTLS(t *testing.T) { + tlsTestReq := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + if err == ErrNoTLS { + dbt.Skip("server does not support TLS") + } else { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") + defer rows.Close() + + var variable, value *sql.RawBytes + for rows.Next() { + if err := rows.Scan(&variable, &value); err != nil { + dbt.Fatal(err.Error()) + } + + if (*value == nil) || (len(*value) == 0) { + dbt.Fatalf("no Cipher") + } else { + dbt.Logf("Cipher: %s", *value) + } + } + } + tlsTestOpt := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + runTests(t, dsn+"&tls=preferred", tlsTestOpt) + runTests(t, dsn+"&tls=skip-verify", tlsTestReq) + + // Verify that registering / using a custom cfg works + RegisterTLSConfig("custom-skip-verify", &tls.Config{ + InsecureSkipVerify: true, + }) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) +} + +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + md := &MySQLDriver{} + conn, err := md.Open(dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("error closing connection: %s", err.Error()) + } + + defer func() { + if err := recover(); err != nil { + t.Errorf("panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != driver.ErrBadConn { + t.Errorf("unexpected error '%s', expected '%s'", + err.Error(), driver.ErrBadConn.Error()) + } +} + +func TestCharset(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mustSetCharset := func(charsetParam, expected string) { + runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() + + if !rows.Next() { + dbt.Fatalf("error getting connection charset: %s", rows.Err()) + } + + var got string + rows.Scan(&got) + + if got != expected { + dbt.Fatalf("expected connection charset %s but got %s", expected, got) + } + }) + } + + // non utf8 test + mustSetCharset("charset=ascii", "ascii") + + // when the first charset is invalid, use the second + mustSetCharset("charset=none,utf8", "utf8") + + // when the first charset is valid, use it + mustSetCharset("charset=ascii,utf8", "ascii") + mustSetCharset("charset=utf8,ascii", "utf8") +} + +func TestFailingCharset(t *testing.T) { + runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + // run query to really establish connection... + _, err := dbt.db.Exec("SELECT 1") + if err == nil { + dbt.db.Close() + t.Fatalf("connection must not succeed without a valid charset") + } + }) +} + +func TestCollation(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + defaultCollation := "utf8mb4_general_ci" + testCollations := []string{ + "", // do not set + defaultCollation, // driver default + "latin1_general_ci", + "binary", + "utf8_unicode_ci", + "cp1257_bin", + } + + for _, collation := range testCollations { + var expected, tdsn string + if collation != "" { + tdsn = dsn + "&collation=" + collation + expected = collation + } else { + tdsn = dsn + expected = defaultCollation + } + + runTests(t, tdsn, func(dbt *DBTest) { + var got string + if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { + dbt.Fatal(err) + } + + if got != expected { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } + }) + } +} + +func TestColumnsWithAlias(t *testing.T) { + runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1 AS A") + defer rows.Close() + cols, _ := rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A" { + t.Fatalf("expected column name \"A\", got \"%s\"", cols[0]) + } + + rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") + defer rows.Close() + cols, _ = rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A.one" { + t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0]) + } + }) +} + +func TestRawBytesResultExceedsBuffer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // defaultBufSize from buffer.go + expected := strings.Repeat("abc", defaultBufSize) + + rows := dbt.mustQuery("SELECT '" + expected + "'") + defer rows.Close() + if !rows.Next() { + dbt.Error("expected result, got none") + } + var result sql.RawBytes + rows.Scan(&result) + if expected != string(result) { + dbt.Error("result did not match expected value") + } + }) +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + + // Regression test for timezone handling + tzTest := func(dbt *DBTest) { + // Create table + dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") + + // Insert local time into database (should be converted) + usCentral, _ := time.LoadLocation("US/Central") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + dbt.mustExec("INSERT INTO test VALUE (?)", reftime) + + // Retrieve time from DB + rows := dbt.mustQuery("SELECT ts FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Fatal("did not get any rows out") + } + + var dbTime time.Time + err := rows.Scan(&dbTime) + if err != nil { + dbt.Fatal("Err", err) + } + + // Check that dates match + if reftime.Unix() != dbTime.Unix() { + dbt.Errorf("times do not match.\n") + dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(UTC)=%v\n", dbTime) + } + } + + for _, tz := range zones { + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) + } +} + +// Special cases + +func TestRowsClose(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows, err := dbt.db.Query("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + err = rows.Close() + if err != nil { + dbt.Fatal(err) + } + + if rows.Next() { + dbt.Fatal("unexpected row after rows.Close()") + } + + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + }) +} + +// dangling statements +// http://code.google.com/p/go/issues/detail?id=3865 +func TestCloseStmtBeforeRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + rows, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows.Close() + + err = stmt.Close() + if err != nil { + dbt.Fatal(err) + } + + if !rows.Next() { + dbt.Fatal("getting row failed") + } else { + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + + var out bool + err = rows.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + }) +} + +// It is valid to have multiple Rows for the same Stmt +// http://code.google.com/p/go/issues/detail?id=3734 +func TestStmtMultiRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") + if err != nil { + dbt.Fatal(err) + } + + rows1, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows1.Close() + + rows2, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows2.Close() + + var out bool + + // 1 + if !rows1.Next() { + dbt.Fatal("first rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + if !rows2.Next() { + dbt.Fatal("first rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + // 2 + if !rows1.Next() { + dbt.Fatal("second rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows1.Next() { + dbt.Fatal("unexpected row on rows1") + } + err = rows1.Close() + if err != nil { + dbt.Fatal(err) + } + } + + if !rows2.Next() { + dbt.Fatal("second rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows2.Next() { + dbt.Fatal("unexpected row on rows2") + } + err = rows2.Close() + if err != nil { + dbt.Fatal(err) + } + } + }) +} + +// Regression test for +// * more than 32 NULL parameters (issue 209) +// * more parameters than fit into the buffer (issue 201) +// * parameters * 64 > max_allowed_packet (issue 734) +func TestPreparedManyCols(t *testing.T) { + numParams := 65535 + runTests(t, dsn, func(dbt *DBTest) { + query := "SELECT ?" + strings.Repeat(",?", numParams-1) + stmt, err := dbt.db.Prepare(query) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + + // create more parameters than fit into the buffer + // which will take nil-values + params := make([]interface{}, numParams) + rows, err := stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + + // Create 0byte string which we can't send via STMT_LONG_DATA. + for i := 0; i < numParams; i++ { + params[i] = "" + } + rows, err = stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + }) +} + +func TestConcurrent(t *testing.T) { + if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { + t.Skip("MYSQL_TEST_CONCURRENT env var not set") + } + + runTests(t, dsn, func(dbt *DBTest) { + var max int + err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + dbt.Logf("testing up to %d concurrent connections \r\n", max) + + var remaining, succeeded int32 = int32(max), 0 + + var wg sync.WaitGroup + wg.Add(max) + + var fatalError string + var once sync.Once + fatalf := func(s string, vals ...interface{}) { + once.Do(func() { + fatalError = fmt.Sprintf(s, vals...) + }) + } + + for i := 0; i < max; i++ { + go func(id int) { + defer wg.Done() + + tx, err := dbt.db.Begin() + atomic.AddInt32(&remaining, -1) + + if err != nil { + if err.Error() != "Error 1040: Too many connections" { + fatalf("error on conn %d: %s", id, err.Error()) + } + return + } + + // keep the connection busy until all connections are open + for remaining > 0 { + if _, err = tx.Exec("DO 1"); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + } + + if err = tx.Commit(); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + + // everything went fine with this connection + atomic.AddInt32(&succeeded, 1) + }(i) + } + + // wait until all conections are open + wg.Wait() + + if fatalError != "" { + dbt.Fatal(fatalError) + } + + dbt.Logf("reached %d concurrent connections\r\n", succeeded) + }) +} + +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, testErr) +} + +// Tests custom dial functions +func TestCustomDial(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // our custom dial function which justs wraps net.Dial here + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + if _, err = db.Exec("DO 1"); err != nil { + t.Fatalf("connection failed: %s", err.Error()) + } +} + +func TestSQLInjection(t *testing.T) { + createTest := func(arg string) func(dbt *DBTest) { + return func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (?)", 1) + + var v int + // NULL can't be equal to anything, the idea here is to inject query so it returns row + // This test verifies that escapeQuotes and escapeBackslash are working properly + err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) + if err == sql.ErrNoRows { + return // success, sql injection failed + } else if err == nil { + dbt.Errorf("sql injection successful with arg: %s", arg) + } else { + dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) + } + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("1 OR 1=1")) + runTests(t, testdsn, createTest("' OR '1'='1")) + } +} + +// Test if inserted data is correctly retrieved after being escaped +func TestInsertRetrieveEscapedData(t *testing.T) { + testData := func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") + + // All sequences that are escaped by escapeQuotes and escapeBackslash + v := "foo \x00\n\r\x1a\"'\\" + dbt.mustExec("INSERT INTO test VALUES (?)", v) + + var out string + err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if out != v { + dbt.Errorf("%q != %q", out, v) + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, testData) + } +} + +func TestUnixSocketAuthFail(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Save the current logger so we can restore it. + oldLogger := errLog + + // Set a new logger so we can capture its output. + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + newLogger := log.New(buffer, "prefix: ", 0) + SetLogger(newLogger) + + // Restore the logger. + defer SetLogger(oldLogger) + + // Make a new DSN that uses the MySQL socket file and a bad password, which + // we can make by simply appending any character to the real password. + badPass := pass + "x" + socket := "" + if prot == "unix" { + socket = addr + } else { + // Get socket file from MySQL. + err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) + if err != nil { + t.Fatalf("error on SELECT @@socket: %s", err.Error()) + } + } + t.Logf("socket: %s", socket) + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) + db, err := sql.Open("mysql", badDSN) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + // Connect to MySQL for real. This will cause an auth failure. + err = db.Ping() + if err == nil { + t.Error("expected Ping() to return an error") + } + + // The driver should not log anything. + if actual := buffer.String(); actual != "" { + t.Errorf("expected no output, got %q", actual) + } + }) +} + +// See Issue #422 +func TestInterruptBySignal(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + DROP PROCEDURE IF EXISTS test_signal; + CREATE PROCEDURE test_signal(ret INT) + BEGIN + SELECT ret; + SIGNAL SQLSTATE + '45001' + SET + MESSAGE_TEXT = "an error", + MYSQL_ERRNO = 45001; + END + `) + defer dbt.mustExec("DROP PROCEDURE test_signal") + + var val int + + // text protocol + rows, err := dbt.db.Query("CALL test_signal(42)") + if err != nil { + dbt.Fatalf("error on text query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + + // binary protocol + rows, err = dbt.db.Query("CALL test_signal(?)", 42) + if err != nil { + dbt.Fatalf("error on binary query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + }) +} + +func TestColumnsReusesSlice(t *testing.T) { + rows := mysqlRows{ + rs: resultSet{ + columns: []mysqlField{ + { + tableName: "test", + name: "A", + }, + { + tableName: "test", + name: "B", + }, + }, + }, + } + + allocs := testing.AllocsPerRun(1, func() { + cols := rows.Columns() + + if len(cols) != 2 { + t.Fatalf("expected 2 columns, got %d", len(cols)) + } + }) + + if allocs != 0 { + t.Fatalf("expected 0 allocations, got %d", int(allocs)) + } + + if rows.rs.columnNames == nil { + t.Fatalf("expected columnNames to be set, got nil") + } +} + +func TestRejectReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read-only. We didn't set the `rejectReadOnly` + // option, so any writes after this should fail. + _, err := dbt.db.Exec("SET SESSION TRANSACTION READ ONLY") + // Error 1193: Unknown system variable 'TRANSACTION' => skip test, + // MySQL server version is too old + maybeSkip(t, err, 1193) + if _, err := dbt.db.Exec("DROP TABLE test"); err == nil { + t.Fatalf("writing to DB in read-only session without " + + "rejectReadOnly did not error") + } + // Set the session back to read-write so runTests() can properly clean + // up the table `test`. + dbt.mustExec("SET SESSION TRANSACTION READ WRITE") + }) + + // Enable the `rejectReadOnly` option. + runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read only. Any writes after this should error on + // a driver.ErrBadConn, and cause `database/sql` to initiate a new + // connection. + dbt.mustExec("SET SESSION TRANSACTION READ ONLY") + // This would error, but `database/sql` should automatically retry on a + // new connection which is not read-only, and eventually succeed. + dbt.mustExec("DROP TABLE test") + }) +} + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} + +// See Issue #799 +func TestEmptyPassword(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) + db, err := sql.Open("mysql", dsn) + if err == nil { + defer db.Close() + err = db.Ping() + } + + if pass == "" { + if err != nil { + t.Fatal(err.Error()) + } + } else { + if err == nil { + t.Fatal("expected authentication error") + } + if !strings.HasPrefix(err.Error(), "Error 1045") { + t.Fatal(err.Error()) + } + } +} + +// static interface implementation checks of mysqlConn +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) + +// static interface implementation checks of mysqlStmt +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} +) + +// Ensure that all the driver interfaces are implemented +var ( + // _ driver.RowsColumnTypeLength = &binaryRows{} + // _ driver.RowsColumnTypeLength = &textRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &textRows{} + _ driver.RowsColumnTypeNullable = &binaryRows{} + _ driver.RowsColumnTypeNullable = &textRows{} + _ driver.RowsColumnTypePrecisionScale = &binaryRows{} + _ driver.RowsColumnTypePrecisionScale = &textRows{} + _ driver.RowsColumnTypeScanType = &binaryRows{} + _ driver.RowsColumnTypeScanType = &textRows{} + _ driver.RowsNextResultSet = &binaryRows{} + _ driver.RowsNextResultSet = &textRows{} +) + +func TestMultiResultSet(t *testing.T) { + type result struct { + values [][]int + columns []string + } + + // checkRows is a helper test function to validate rows containing 3 result + // sets with specific values and columns. The basic query would look like this: + // + // SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + // SELECT 0 UNION SELECT 1; + // SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + // + // to distinguish test cases the first string argument is put in front of + // every error or fatal message. + checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) { + expected := []result{ + { + values: [][]int{{1, 2}, {3, 4}}, + columns: []string{"col1", "col2"}, + }, + { + values: [][]int{{1, 2, 3}, {4, 5, 6}}, + columns: []string{"col1", "col2", "col3"}, + }, + } + + var res1 result + for rows.Next() { + var res [2]int + if err := rows.Scan(&res[0], &res[1]); err != nil { + dbt.Fatal(err) + } + res1.values = append(res1.values, res[:]) + } + + cols, err := rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res1.columns = cols + + if !reflect.DeepEqual(expected[0], res1) { + dbt.Error(desc, "want =", expected[0], "got =", res1) + } + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + // ignoring one result set + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + var res2 result + cols, err = rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res2.columns = cols + + for rows.Next() { + var res [3]int + if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { + dbt.Fatal(desc, err) + } + res2.values = append(res2.values, res[:]) + } + + if !reflect.DeepEqual(expected[1], res2) { + dbt.Error(desc, "want =", expected[1], "got =", res2) + } + + if rows.NextResultSet() { + dbt.Error(desc, "unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error(desc, err) + } + } + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery(`DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`) + defer rows.Close() + checkRows("query: ", rows, dbt) + }) + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + queries := []string{ + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + } + + defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss") + + for i, query := range queries { + dbt.mustExec(query) + + stmt, err := dbt.db.Prepare("CALL test_mrss()") + if err != nil { + dbt.Fatalf("%v (i=%d)", err, i) + } + defer stmt.Close() + + for j := 0; j < 2; j++ { + rows, err := stmt.Query() + if err != nil { + dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j) + } + checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt) + } + } + }) +} + +func TestMultiResultSetNoSelect(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("DO 1; DO 2;") + defer rows.Close() + + if rows.Next() { + dbt.Error("unexpected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +// tests if rows are set in a proper state if some results were ignored before +// calling rows.NextResultSet. +func TestSkipResults(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1, 2") + defer rows.Close() + + if !rows.Next() { + dbt.Error("expected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +func TestPingContext(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := dbt.db.PingContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + dbt.Error("expected error") + } else if err.Error() != "context canceled" { + dbt.Fatalf("unexpected error: %s", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQueryRow(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + // the first row will be succeed. + var v int + if !rows.Next() { + dbt.Fatalf("unexpected end") + } + if err := rows.Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + + cancel() + // make sure the driver receives the cancel request. + time.Sleep(100 * time.Millisecond) + + if rows.Next() { + dbt.Errorf("expected end, but not") + } + if err := rows.Err(); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelPrepare(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelStmtExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.ExecContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelStmtQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.QueryContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelBegin(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + tx, err := dbt.db.BeginTx(ctx, nil) + if err != nil { + dbt.Fatal(err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Transaction is canceled, so expect an error. + switch err := tx.Commit(); err { + case sql.ErrTxDone: + // because the transaction has already been rollbacked. + // the database/sql package watches ctx + // and rollbacks when ctx is canceled. + case context.Canceled: + // the database/sql package rollbacks on another goroutine, + // so the transaction may not be rollbacked depending on goroutine scheduling. + default: + dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) + } + + // Context is canceled, so cannot begin a transaction. + if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextBeginIsolationLevel(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx1, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + }) + if err != nil { + dbt.Fatal(err) + } + + tx2, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelReadCommitted, + }) + if err != nil { + dbt.Fatal(err) + } + + _, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if err != nil { + dbt.Fatal(err) + } + + var v int + row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Because writer transaction wasn't commited yet, it should be available + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + err = tx1.Commit() + if err != nil { + dbt.Fatal(err) + } + + row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Data written by writer transaction is already commited, it should be selectable + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + tx2.Commit() + }) +} + +func TestContextBeginReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: true, + }) + if _, ok := err.(*MySQLError); ok { + dbt.Skip("It seems that your MySQL does not support READ ONLY transactions") + return + } else if err != nil { + dbt.Fatal(err) + } + + // INSERT queries fail in a READ ONLY transaction. + _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if _, ok := err.(*MySQLError); !ok { + dbt.Errorf("expected MySQLError, got %v", err) + } + + // SELECT queries can be executed. + var v int + row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + if err := tx.Commit(); err != nil { + dbt.Fatal(err) + } + }) +} + +func TestRowsColumnTypes(t *testing.T) { + niNULL := sql.NullInt64{Int64: 0, Valid: false} + ni0 := sql.NullInt64{Int64: 0, Valid: true} + ni1 := sql.NullInt64{Int64: 1, Valid: true} + ni42 := sql.NullInt64{Int64: 42, Valid: true} + nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} + nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} + nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} + nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := NullTime{Time: time.Time{}, Valid: false} + rbNULL := sql.RawBytes(nil) + rb0 := sql.RawBytes("0") + rb42 := sql.RawBytes("42") + rbTest := sql.RawBytes("Test") + rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 + rbx0 := sql.RawBytes("\x00") + rbx42 := sql.RawBytes("\x42") + + var columns = []struct { + name string + fieldType string // type used when creating table schema + databaseTypeName string // actual type used by MySQL + scanType reflect.Type + nullable bool + precision int64 // 0 if not ok + scale int64 + valuesIn [3]string + valuesOut [3]interface{} + }{ + {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, + {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, + {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, + {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}}, + {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, + {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, + {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, + {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, + {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, + {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, + {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, + {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, + } + + schema := "" + values1 := "" + values2 := "" + values3 := "" + for _, column := range columns { + schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType) + values1 += column.valuesIn[0] + ", " + values2 += column.valuesIn[1] + ", " + values3 += column.valuesIn[2] + ", " + } + schema = schema[:len(schema)-2] + values1 = values1[:len(values1)-2] + values2 = values2[:len(values2)-2] + values3 = values3[:len(values3)-2] + + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } + + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] + + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } + + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } + + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) + } + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { + if !ok { + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) + } + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) + } + continue + } + } + + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) + } + } + } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) + } +} + +func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") + dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) + // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() + }) +} + +// TestRawBytesAreNotModified checks for a race condition that arises when a query context +// is canceled while a user is calling rows.Scan. This is a more stringent test than the one +// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using +// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit +// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers. +func TestRawBytesAreNotModified(t *testing.T) { + const blob = "abcdefghijklmnop" + const contextRaceIterations = 20 + const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row. + const insertRows = 4 + + var sqlBlobs = [2]string{ + strings.Repeat(blob, blobSize/len(blob)), + strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)), + } + + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + for i := 0; i < insertRows; i++ { + dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) + if err != nil { + t.Fatal(err) + } + + var b int + var raw sql.RawBytes + for rows.Next() { + if err := rows.Scan(&b, &raw); err != nil { + t.Fatal(err) + } + + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) + + if before != after { + t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + } + } + rows.Close() + }() + } + }) +} + +var _ driver.DriverContext = &MySQLDriver{} + +type dialCtxKey struct{} + +func TestConnectorObeysDialTimeouts(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + if !ctx.Value(dialCtxKey{}).(bool) { + return nil, fmt.Errorf("test error: query context is not propagated to our dialer") + } + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + ctx := context.WithValue(context.Background(), dialCtxKey{}, true) + + _, err = db.ExecContext(ctx, "DO 1") + if err != nil { + t.Fatal(err) + } +} + +func configForTests(t *testing.T) *Config { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mycnf := NewConfig() + mycnf.User = user + mycnf.Passwd = pass + mycnf.Addr = addr + mycnf.Net = prot + mycnf.DBName = dbname + return mycnf +} + +func TestNewConnector(t *testing.T) { + mycnf := configForTests(t) + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +type slowConnection struct { + net.Conn + slowdown time.Duration +} + +func (sc *slowConnection) Read(b []byte) (int, error) { + time.Sleep(sc.slowdown) + return sc.Conn.Read(b) +} + +type connectorHijack struct { + driver.Connector + connErr error +} + +func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + conn, cw.connErr = cw.Connector.Connect(ctx) + return conn, cw.connErr +} + +func TestConnectorTimeoutsDuringOpen(t *testing.T) { + RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, prot, addr) + if err != nil { + return nil, err + } + return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil + }) + + mycnf := configForTests(t) + mycnf.Net = "slowconn" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + hijack := &connectorHijack{Connector: conn} + + db := sql.OpenDB(hijack) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = db.ExecContext(ctx, "DO 1") + if err != context.DeadlineExceeded { + t.Fatalf("ExecContext should have timed out") + } + if hijack.connErr != context.DeadlineExceeded { + t.Fatalf("(*Connector).Connect should have timed out") + } +} + +// A connection which can only be closed. +type dummyConnection struct { + net.Conn + closed bool +} + +func (d *dummyConnection) Close() error { + d.closed = true + return nil +} + +func TestConnectorTimeoutsWatchCancel(t *testing.T) { + var ( + cancel func() // Used to cancel the context just after connecting. + created *dummyConnection // The created connection. + ) + + RegisterDialContext("TestConnectorTimeoutsWatchCancel", func(ctx context.Context, addr string) (net.Conn, error) { + // Canceling at this time triggers the watchCancel error branch in Connect(). + cancel() + created = &dummyConnection{} + return created, nil + }) + + mycnf := NewConfig() + mycnf.User = "root" + mycnf.Addr = "foo" + mycnf.Net = "TestConnectorTimeoutsWatchCancel" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + + if _, err := db.Conn(ctx); err != context.Canceled { + t.Errorf("got %v, want context.Canceled", err) + } + + if created == nil { + t.Fatal("no connection created") + } + if !created.closed { + t.Errorf("connection not closed") + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go new file mode 100644 index 0000000..75c8c24 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -0,0 +1,560 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "errors" + "fmt" + "math/big" + "net" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +var ( + errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") +) + +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE + AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowNativePasswords bool // Allows the native password authentication method + AllowOldPasswords bool // Allows the old insecure password method + CheckConnLiveness bool // Check connections for liveness before using them + ClientFoundRows bool // Return number of matching rows instead of rows changed + ColumnsWithAlias bool // Prepend table alias to column names + InterpolateParams bool // Interpolate placeholders into query string + MultiStatements bool // Allow multiple statements in one query + ParseTime bool // Parse time values to time.Time + RejectReadOnly bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, + AllowNativePasswords: true, + CheckConnLiveness: true, + } +} + +func (cfg *Config) Clone() *Config { + cp := *cfg + if cp.tls != nil { + cp.tls = cfg.tls.Clone() + } + if len(cp.Params) > 0 { + cp.Params = make(map[string]string, len(cfg.Params)) + for k, v := range cfg.Params { + cp.Params[k] = v + } + } + if cfg.pubKey != nil { + cp.pubKey = &rsa.PublicKey{ + N: new(big.Int).Set(cfg.pubKey.N), + E: cfg.pubKey.E, + } + } + return &cp +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + switch cfg.TLSConfig { + case "false", "": + // don't set anything + case "true": + cfg.tls = &tls.Config{} + case "skip-verify", "preferred": + cfg.tls = &tls.Config{InsecureSkipVerify: true} + default: + cfg.tls = getTLSConfigClone(cfg.TLSConfig) + if cfg.tls == nil { + return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + } + } + + if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + + if cfg.ServerPubKey != "" { + cfg.pubKey = getServerPubKey(cfg.ServerPubKey) + if cfg.pubKey == nil { + return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey) + } + } + + return nil +} + +func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) { + buf.Grow(1 + len(name) + 1 + len(value)) + if !*hasParam { + *hasParam = true + buf.WriteByte('?') + } else { + buf.WriteByte('&') + } + buf.WriteString(name) + buf.WriteByte('=') + buf.WriteString(value) +} + +// FormatDSN formats the given Config into a DSN string which can be passed to +// the driver. +func (cfg *Config) FormatDSN() string { + var buf bytes.Buffer + + // [username[:password]@] + if len(cfg.User) > 0 { + buf.WriteString(cfg.User) + if len(cfg.Passwd) > 0 { + buf.WriteByte(':') + buf.WriteString(cfg.Passwd) + } + buf.WriteByte('@') + } + + // [protocol[(address)]] + if len(cfg.Net) > 0 { + buf.WriteString(cfg.Net) + if len(cfg.Addr) > 0 { + buf.WriteByte('(') + buf.WriteString(cfg.Addr) + buf.WriteByte(')') + } + } + + // /dbname + buf.WriteByte('/') + buf.WriteString(cfg.DBName) + + // [?param1=value1&...¶mN=valueN] + hasParam := false + + if cfg.AllowAllFiles { + hasParam = true + buf.WriteString("?allowAllFiles=true") + } + + if cfg.AllowCleartextPasswords { + writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true") + } + + if !cfg.AllowNativePasswords { + writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false") + } + + if cfg.AllowOldPasswords { + writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true") + } + + if !cfg.CheckConnLiveness { + writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false") + } + + if cfg.ClientFoundRows { + writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") + } + + if col := cfg.Collation; col != defaultCollation && len(col) > 0 { + writeDSNParam(&buf, &hasParam, "collation", col) + } + + if cfg.ColumnsWithAlias { + writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") + } + + if cfg.InterpolateParams { + writeDSNParam(&buf, &hasParam, "interpolateParams", "true") + } + + if cfg.Loc != time.UTC && cfg.Loc != nil { + writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String())) + } + + if cfg.MultiStatements { + writeDSNParam(&buf, &hasParam, "multiStatements", "true") + } + + if cfg.ParseTime { + writeDSNParam(&buf, &hasParam, "parseTime", "true") + } + + if cfg.ReadTimeout > 0 { + writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String()) + } + + if cfg.RejectReadOnly { + writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true") + } + + if len(cfg.ServerPubKey) > 0 { + writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey)) + } + + if cfg.Timeout > 0 { + writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String()) + } + + if len(cfg.TLSConfig) > 0 { + writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig)) + } + + if cfg.WriteTimeout > 0 { + writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String()) + } + + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { + writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket)) + } + + // other params + if cfg.Params != nil { + var params []string + for param := range cfg.Params { + params = append(params, param) + } + sort.Strings(params) + for _, param := range params { + writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param])) + } + } + + return buf.String() +} + +// ParseDSN parses the DSN string to a Config +func ParseDSN(dsn string) (cfg *Config, err error) { + // New config with some default values + cfg = NewConfig() + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.Passwd = dsn[k+1 : j] + break + } + } + cfg.User = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.Addr = dsn[k+1 : i-1] + break + } + } + cfg.Net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.DBName = dsn[i+1 : j] + + break + } + } + + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + + if err = cfg.normalize(); err != nil { + return nil, err + } + return +} + +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *Config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + // Disable INFILE whitelist / enable all files + case "allowAllFiles": + var isBool bool + cfg.AllowAllFiles, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use cleartext authentication mode (MySQL 5.5.10+) + case "allowCleartextPasswords": + var isBool bool + cfg.AllowCleartextPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use native password authentication + case "allowNativePasswords": + var isBool bool + cfg.AllowNativePasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.AllowOldPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Check connections for Liveness before using them + case "checkConnLiveness": + var isBool bool + cfg.CheckConnLiveness, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Switch "rowsAffected" mode + case "clientFoundRows": + var isBool bool + cfg.ClientFoundRows, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Collation + case "collation": + cfg.Collation = value + break + + case "columnsWithAlias": + var isBool bool + cfg.ColumnsWithAlias, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Compression + case "compress": + return errors.New("compression not implemented yet") + + // Enable client side placeholder substitution + case "interpolateParams": + var isBool bool + cfg.InterpolateParams, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Time Location + case "loc": + if value, err = url.QueryUnescape(value); err != nil { + return + } + cfg.Loc, err = time.LoadLocation(value) + if err != nil { + return + } + + // multiple statements in one query + case "multiStatements": + var isBool bool + cfg.MultiStatements, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // time.Time parsing + case "parseTime": + var isBool bool + cfg.ParseTime, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // I/O read Timeout + case "readTimeout": + cfg.ReadTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // Reject read-only connections + case "rejectReadOnly": + var isBool bool + cfg.RejectReadOnly, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + cfg.ServerPubKey = name + + // Strict mode + case "strict": + panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + + // Dial Timeout + case "timeout": + cfg.Timeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // TLS-Encryption + case "tls": + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.TLSConfig = "true" + } else { + cfg.TLSConfig = "false" + } + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { + cfg.TLSConfig = vl + } else { + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for TLS config name: %v", err) + } + cfg.TLSConfig = name + } + + // I/O write Timeout + case "writeTimeout": + cfg.WriteTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + case "maxAllowedPacket": + cfg.MaxAllowedPacket, err = strconv.Atoi(value) + if err != nil { + return + } + default: + // lazy init + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + } + + return +} + +func ensureHavePort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "3306") + } + return addr +} diff --git a/vendor/github.com/go-sql-driver/mysql/dsn_test.go b/vendor/github.com/go-sql-driver/mysql/dsn_test.go new file mode 100644 index 0000000..89815b3 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/dsn_test.go @@ -0,0 +1,415 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/tls" + "fmt" + "net/url" + "reflect" + "testing" + "time" +) + +var testDSNs = []struct { + in string + out *Config +}{{ + "username:password@protocol(address)/dbname?param=value", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, +}, { + "user@unix(/path/to/socket)/dbname?charset=utf8", + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, +}, { + "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, +}, { + "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, +}, { + "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "@/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "user:p@/ssword@/", + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "unix/?arg=%2Fsome%2Fpath.ext", + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "tcp(127.0.0.1)/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "tcp(de:ad:be:ef::ca:fe)/dbname", + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, +} + +func TestDSNParser(t *testing.T) { + for i, tst := range testDSNs { + cfg, err := ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + } + + // pointer not static + cfg.tls = nil + + if !reflect.DeepEqual(cfg, tst.out) { + t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) + } + } +} + +func TestDSNParserInvalid(t *testing.T) { + var invalidDSNs = []string{ + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + "net()/", // unknown default addr + //"/dbname?arg=/some/unescaped/path", + } + + for i, tst := range invalidDSNs { + if _, err := ParseDSN(tst); err == nil { + t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) + } + } +} + +func TestDSNReformat(t *testing.T) { + for i, tst := range testDSNs { + dsn1 := tst.in + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Error(err.Error()) + continue + } + cfg1.tls = nil // pointer not static + res1 := fmt.Sprintf("%+v", cfg1) + + dsn2 := cfg1.FormatDSN() + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Error(err.Error()) + continue + } + cfg2.tls = nil // pointer not static + res2 := fmt.Sprintf("%+v", cfg2) + + if res1 != res2 { + t.Errorf("%d. %q does not match %q", i, res2, res1) + } + } +} + +func TestDSNServerPubKey(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + tst := baseDSN + "testKey" + cfg, err := ParseDSN(tst) + if err != nil { + t.Error(err.Error()) + } + + if cfg.ServerPubKey != "testKey" { + t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey) + } + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } + + // Key is missing + tst = baseDSN + "invalid_name" + cfg, err = ParseDSN(tst) + if err == nil { + t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } +} + +func TestDSNServerPubKeyQueryEscape(t *testing.T) { + const name = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name) + + RegisterServerPubKey(name, testPubKeyRSA) + defer DeregisterServerPubKey(name) + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } +} + +func TestDSNWithCustomTLS(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?tls=" + tlsCfg := tls.Config{} + + RegisterTLSConfig("utils_test", &tlsCfg) + defer DeregisterTLSConfig("utils_test") + + // Custom TLS is missing + tst := baseDSN + "invalid_tls" + cfg, err := ParseDSN(tst) + if err == nil { + t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } + + tst = baseDSN + "utils_test" + + // Custom TLS with a server name + name := "foohost" + tlsCfg.ServerName = name + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) + } + + // Custom TLS without a server name + name = "localhost" + tlsCfg.ServerName = "" + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) + } else if tlsCfg.ServerName != "" { + t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) + } +} + +func TestDSNTLSConfig(t *testing.T) { + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true" + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + dsn = "tcp(example.com)/?tls=true" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } +} + +func TestDSNWithCustomTLSQueryEscape(t *testing.T) { + const configKey = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) + name := "foohost" + tlsCfg := tls.Config{ServerName: name} + + RegisterTLSConfig(configKey, &tlsCfg) + defer DeregisterTLSConfig(configKey) + + cfg, err := ParseDSN(dsn) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) + } +} + +func TestDSNUnsafeCollation(t *testing.T) { + _, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") + if err != errInvalidDSNUnsafeCollation { + t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } +} + +func TestParamsAreSorted(t *testing.T) { + expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo" + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.InterpolateParams = true + cfg.Params = map[string]string{ + "quux": "loo", + "foobar": "baz", + } + actual := cfg.FormatDSN() + if actual != expected { + t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual) + } +} + +func TestCloneConfig(t *testing.T) { + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey" + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatal(err.Error()) + } + + cfg2 := cfg.Clone() + if cfg == cfg2 { + t.Errorf("Config.Clone did not create a separate config struct") + } + + if cfg2.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + cfg2.tls.ServerName = "example2.com" + if cfg.tls.ServerName == cfg2.tls.ServerName { + t.Errorf("changed cfg.tls.Server name should not propagate to original Config") + } + + if _, ok := cfg2.Params["foobar"]; !ok { + t.Errorf("cloned Config is missing custom params") + } + + delete(cfg2.Params, "foobar") + + if _, ok := cfg.Params["foobar"]; !ok { + t.Errorf("custom params in cloned Config should not propagate to original Config") + } + + if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) { + t.Errorf("public key in Config should be identical") + } +} + +func TestNormalizeTLSConfig(t *testing.T) { + tt := []struct { + tlsConfig string + want *tls.Config + }{ + {"", nil}, + {"false", nil}, + {"true", &tls.Config{ServerName: "myserver"}}, + {"skip-verify", &tls.Config{InsecureSkipVerify: true}}, + {"preferred", &tls.Config{InsecureSkipVerify: true}}, + {"test_tls_config", &tls.Config{ServerName: "myServerName"}}, + } + + RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"}) + defer func() { DeregisterTLSConfig("test_tls_config") }() + + for _, tc := range tt { + t.Run(tc.tlsConfig, func(t *testing.T) { + cfg := &Config{ + Addr: "myserver:3306", + TLSConfig: tc.tlsConfig, + } + + cfg.normalize() + + if cfg.tls == nil { + if tc.want != nil { + t.Fatal("wanted a tls config but got nil instead") + } + return + } + + if cfg.tls.ServerName != tc.want.ServerName { + t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')", + tc.want.ServerName, cfg.tls.ServerName) + } + if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify { + t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)", + tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify) + } + }) + } +} + +func BenchmarkParseDSN(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tst := range testDSNs { + if _, err := ParseDSN(tst.in); err != nil { + b.Error(err.Error()) + } + } + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/errors.go b/vendor/github.com/go-sql-driver/mysql/errors.go index 44cf30d..760782f 100644 --- a/vendor/github.com/go-sql-driver/mysql/errors.go +++ b/vendor/github.com/go-sql-driver/mysql/errors.go @@ -9,30 +9,35 @@ package mysql import ( - "database/sql/driver" "errors" "fmt" - "io" "log" "os" ) // Various errors the driver might return. Can change between driver versions. var ( - ErrInvalidConn = errors.New("Invalid Connection") - ErrMalformPkt = errors.New("Malformed Packet") - ErrNoTLS = errors.New("TLS encryption requested but server does not support TLS") - ErrOldPassword = errors.New("This user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") - ErrCleartextPassword = errors.New("This user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN.") - ErrUnknownPlugin = errors.New("The authentication plugin is not supported.") - ErrOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+") - ErrPktSync = errors.New("Commands out of sync. You can't run this command now") - ErrPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?") - ErrPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.") - ErrBusyBuffer = errors.New("Busy buffer") + ErrInvalidConn = errors.New("invalid connection") + ErrMalformPkt = errors.New("malformed packet") + ErrNoTLS = errors.New("TLS requested but server does not support TLS") + ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") + ErrNativePassword = errors.New("this user requires mysql native password authentication.") + ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") + ErrUnknownPlugin = errors.New("this authentication plugin is not supported") + ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") + ErrPktSync = errors.New("commands out of sync. You can't run this command now") + ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") + ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") + ErrBusyBuffer = errors.New("busy buffer") + + // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. + // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn + // to trigger a resend. + // See https://github.com/go-sql-driver/mysql/pull/302 + errBadConnNoWrite = errors.New("bad connection") ) -var errLog Logger = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) +var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) // Logger is used to log critical error messages. type Logger interface { @@ -58,74 +63,3 @@ type MySQLError struct { func (me *MySQLError) Error() string { return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } - -// MySQLWarnings is an error type which represents a group of one or more MySQL -// warnings -type MySQLWarnings []MySQLWarning - -func (mws MySQLWarnings) Error() string { - var msg string - for i, warning := range mws { - if i > 0 { - msg += "\r\n" - } - msg += fmt.Sprintf( - "%s %s: %s", - warning.Level, - warning.Code, - warning.Message, - ) - } - return msg -} - -// MySQLWarning is an error type which represents a single MySQL warning. -// Warnings are returned in groups only. See MySQLWarnings -type MySQLWarning struct { - Level string - Code string - Message string -} - -func (mc *mysqlConn) getWarnings() (err error) { - rows, err := mc.Query("SHOW WARNINGS", nil) - if err != nil { - return - } - - var warnings = MySQLWarnings{} - var values = make([]driver.Value, 3) - - for { - err = rows.Next(values) - switch err { - case nil: - warning := MySQLWarning{} - - if raw, ok := values[0].([]byte); ok { - warning.Level = string(raw) - } else { - warning.Level = fmt.Sprintf("%s", values[0]) - } - if raw, ok := values[1].([]byte); ok { - warning.Code = string(raw) - } else { - warning.Code = fmt.Sprintf("%s", values[1]) - } - if raw, ok := values[2].([]byte); ok { - warning.Message = string(raw) - } else { - warning.Message = fmt.Sprintf("%s", values[0]) - } - - warnings = append(warnings, warning) - - case io.EOF: - return warnings - - default: - rows.Close() - return - } - } -} diff --git a/vendor/github.com/go-sql-driver/mysql/errors_test.go b/vendor/github.com/go-sql-driver/mysql/errors_test.go new file mode 100644 index 0000000..96f9126 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/errors_test.go @@ -0,0 +1,42 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "log" + "testing" +) + +func TestErrorsSetLogger(t *testing.T) { + previous := errLog + defer func() { + errLog = previous + }() + + // set up logger + const expected = "prefix: test\n" + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + logger := log.New(buffer, "prefix: ", 0) + + // print + SetLogger(logger) + errLog.Print("test") + + // check result + if actual := buffer.String(); actual != expected { + t.Errorf("expected %q, got %q", expected, actual) + } +} + +func TestErrorsStrictIgnoreNotes(t *testing.T) { + runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) { + dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") + }) +} diff --git a/vendor/github.com/go-sql-driver/mysql/fields.go b/vendor/github.com/go-sql-driver/mysql/fields.go new file mode 100644 index 0000000..e1e2ece --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/fields.go @@ -0,0 +1,194 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "reflect" +) + +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte + charSet uint8 +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/go.mod b/vendor/github.com/go-sql-driver/mysql/go.mod new file mode 100644 index 0000000..fffbf6a --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/go.mod @@ -0,0 +1,3 @@ +module github.com/go-sql-driver/mysql + +go 1.10 diff --git a/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/go-sql-driver/mysql/infile.go index 84c53a9..273cb0b 100644 --- a/vendor/github.com/go-sql-driver/mysql/infile.go +++ b/vendor/github.com/go-sql-driver/mysql/infile.go @@ -96,6 +96,10 @@ func deferredClose(err *error, closer io.Closer) { func (mc *mysqlConn) handleInFileRequest(name string) (err error) { var rdr io.Reader var data []byte + packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + if mc.maxWriteSize < packetSize { + packetSize = mc.maxWriteSize + } if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader // The server might return an an absolute path. See issue #355. @@ -108,8 +112,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { if inMap { rdr = handler() if rdr != nil { - data = make([]byte, 4+mc.maxWriteSize) - if cl, ok := rdr.(io.Closer); ok { defer deferredClose(&err, cl) } @@ -124,7 +126,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { fileRegisterLock.RLock() fr := fileRegister[name] fileRegisterLock.RUnlock() - if mc.cfg.allowAllFiles || fr { + if mc.cfg.AllowAllFiles || fr { var file *os.File var fi os.FileInfo @@ -134,22 +136,20 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // get file size if fi, err = file.Stat(); err == nil { rdr = file - if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize { - data = make([]byte, 4+fileSize) - } else if fileSize <= mc.maxPacketAllowed { - data = make([]byte, 4+mc.maxWriteSize) - } else { - err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed) + if fileSize := int(fi.Size()); fileSize < packetSize { + packetSize = fileSize } } } } else { - err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name) + err = fmt.Errorf("local file '%s' is not registered", name) } } // send content packets - if err == nil { + // if packetSize == 0, the Reader contains no data + if err == nil && packetSize > 0 { + data := make([]byte, 4+packetSize) var n int for err == nil { n, err = rdr.Read(data[4:]) @@ -175,8 +175,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // read OK packet if err == nil { return mc.readResultOK() - } else { - mc.readPacket() } + + mc.readPacket() return err } diff --git a/vendor/github.com/go-sql-driver/mysql/nulltime.go b/vendor/github.com/go-sql-driver/mysql/nulltime.go new file mode 100644 index 0000000..afa8a89 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/nulltime.go @@ -0,0 +1,50 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "fmt" + "time" +) + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = parseDateTime(string(v), time.UTC) + nt.Valid = (err == nil) + return + case string: + nt.Time, err = parseDateTime(v, time.UTC) + nt.Valid = (err == nil) + return + } + + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/nulltime_go113.go b/vendor/github.com/go-sql-driver/mysql/nulltime_go113.go new file mode 100644 index 0000000..c392594 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/nulltime_go113.go @@ -0,0 +1,31 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.13 + +package mysql + +import ( + "database/sql" +) + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +type NullTime sql.NullTime diff --git a/vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go b/vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go new file mode 100644 index 0000000..86d159d --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go @@ -0,0 +1,34 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !go1.13 + +package mysql + +import ( + "time" +) + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} diff --git a/vendor/github.com/go-sql-driver/mysql/nulltime_test.go b/vendor/github.com/go-sql-driver/mysql/nulltime_test.go new file mode 100644 index 0000000..a14ec06 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/nulltime_test.go @@ -0,0 +1,62 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "database/sql/driver" + "testing" + "time" +) + +var ( + // Check implementation of interfaces + _ driver.Valuer = NullTime{} + _ sql.Scanner = (*NullTime)(nil) +) + +func TestScanNullTime(t *testing.T) { + var scanTests = []struct { + in interface{} + error bool + valid bool + time time.Time + }{ + {tDate, false, true, tDate}, + {sDate, false, true, tDate}, + {[]byte(sDate), false, true, tDate}, + {tDateTime, false, true, tDateTime}, + {sDateTime, false, true, tDateTime}, + {[]byte(sDateTime), false, true, tDateTime}, + {tDate0, false, true, tDate0}, + {sDate0, false, true, tDate0}, + {[]byte(sDate0), false, true, tDate0}, + {sDateTime0, false, true, tDate0}, + {[]byte(sDateTime0), false, true, tDate0}, + {"", true, false, tDate0}, + {"1234", true, false, tDate0}, + {0, true, false, tDate0}, + } + + var nt = NullTime{} + var err error + + for _, tst := range scanTests { + err = nt.Scan(tst.in) + if (err != nil) != tst.error { + t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil)) + } + if nt.Valid != tst.valid { + t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid) + } + if nt.Time != tst.time { + t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time) + } + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index 76cb7c8..82ad7a2 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" "math" @@ -24,55 +25,66 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { - var payload []byte + var prevData []byte for { - // Read packet header + // read packet header data, err := mc.buf.readNext(4) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } - // Packet Length [24 bit] + // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - if pktLen < 1 { - errLog.Print(ErrMalformPkt) - mc.Close() - return nil, driver.ErrBadConn - } - - // Check Packet Sync [8 bit] + // check packet sync [8 bit] if data[3] != mc.sequence { if data[3] > mc.sequence { return nil, ErrPktSyncMul - } else { - return nil, ErrPktSync } + return nil, ErrPktSync } mc.sequence++ - // Read packet body [pktLen bytes] + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)-1 bytes long + if pktLen == 0 { + // there was no previous packet + if prevData == nil { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, ErrInvalidConn + } + + return prevData, nil + } + + // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } - isLastPacket := (pktLen < maxPacketSize) + // return data if this was the last packet + if pktLen < maxPacketSize { + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } - // Zero allocations for non-splitting packets - if isLastPacket && payload == nil { - return data, nil + return append(prevData, data...), nil } - payload = append(payload, data...) - - if isLastPacket { - return payload, nil - } + prevData = append(prevData, data...) } } @@ -80,10 +92,39 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 - if pktLen > mc.maxPacketAllowed { + if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.reset { + mc.reset = false + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + var err error + // If this connection has a ReadTimeout which we've been setting on + // reads, reset it to its default value before we attempt a non-blocking + // read, otherwise the scheduler will just time us out before we can read + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Time{}) + } + if err == nil && mc.cfg.CheckConnLiveness { + err = connCheck(conn) + } + if err != nil { + errLog.Print("closing bad idle connection: ", err) + mc.Close() + return driver.ErrBadConn + } + } + for { var size int if pktLen >= maxPacketSize { @@ -100,6 +141,12 @@ func (mc *mysqlConn) writePacket(data []byte) error { data[3] = mc.sequence // Write packet + if mc.writeTimeout > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + return err + } + } + n, err := mc.netConn.Write(data[:4+size]) if err == nil && n == 4+size { mc.sequence++ @@ -113,34 +160,48 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) + mc.cleanup() errLog.Print(ErrMalformPkt) } else { + if cerr := mc.canceled.Value(); cerr != nil { + return cerr + } + if n == 0 && pktLen == len(data)-4 { + // only for the first loop iteration when nothing was written yet + return errBadConnNoWrite + } + mc.cleanup() errLog.Print(err) } - return driver.ErrBadConn + return ErrInvalidConn } } /****************************************************************************** -* Initialisation Process * +* Initialization Process * ******************************************************************************/ // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { + data, err = mc.readPacket() if err != nil { - return nil, err + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. + if err == ErrInvalidConn { + return nil, "", driver.ErrBadConn + } + return } if data[0] == iERR { - return nil, mc.handleErrorPacket(data) + return nil, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, fmt.Errorf( - "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", + return nil, "", fmt.Errorf( + "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, ) @@ -151,7 +212,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] - cipher := data[pos : pos+8] + authData := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -159,10 +220,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { - return nil, ErrOldProtocol + return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, ErrNoTLS + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } } pos += 2 @@ -186,32 +251,32 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // // The official Python library uses the fixed length 12 // which seems to work but technically could have a hidden bug. - cipher = append(cipher, data[pos:pos+12]...) + authData = append(authData, data[pos:pos+12]...) + pos += 13 - // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) // \NUL otherwise - // - //if data[len(data)-1] == 0 { - // return - //} - //return ErrMalformPkt + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + plugin = string(data[pos : pos+end]) + } else { + plugin = string(data[pos:]) + } // make a memory safe copy of the cipher slice var b [20]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -219,9 +284,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientTransactions | clientLocalFiles | clientPluginAuth | + clientMultiResults | mc.flags&clientLongFlag - if mc.cfg.clientFoundRows { + if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } @@ -230,23 +296,34 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientSSL } - // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd)) + if mc.cfg.MultiStatements { + clientFlags |= clientMultiStatements + } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) + 21 + 1 + // encode length of the auth plugin data + var authRespLEIBuf [9]byte + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + clientFlags |= clientPluginAuthLenEncClientData + } + + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name - if n := len(mc.cfg.dbname); n > 0 { + if n := len(mc.cfg.DBName); n > 0 { clientFlags |= clientConnectWithDB pktLen += n + 1 } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // ClientFlags [32 bit] @@ -262,7 +339,14 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[11] = 0x00 // Charset [1 byte] - data[12] = mc.cfg.collation + var found bool + data[12], found = collations[mc.cfg.Collation] + if !found { + // Note possibility for false negatives: + // could be triggered although the collation is valid if the + // collations map does not contain entries the server supports. + return errors.New("unknown collation") + } // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -277,8 +361,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { if err := tlsConn.Handshake(); err != nil { return err } + mc.rawConn = mc.netConn mc.netConn = tlsConn - mc.buf.rd = tlsConn + mc.buf.nc = tlsConn } // Filler [23 bytes] (all 0x00) @@ -288,69 +373,43 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // User [null terminated string] - if len(mc.cfg.user) > 0 { - pos += copy(data[pos:], mc.cfg.user) + if len(mc.cfg.User) > 0 { + pos += copy(data[pos:], mc.cfg.User) } data[pos] = 0x00 pos++ - // ScrambleBuffer [length encoded integer] - data[pos] = byte(len(scrambleBuff)) - pos += 1 + copy(data[pos+1:], scrambleBuff) + // Auth Data [length encoded integer] + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], authResp) // Databasename [null terminated string] - if len(mc.cfg.dbname) > 0 { - pos += copy(data[pos:], mc.cfg.dbname) + if len(mc.cfg.DBName) > 0 { + pos += copy(data[pos:], mc.cfg.DBName) data[pos] = 0x00 pos++ } - // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], plugin) data[pos] = 0x00 + pos++ // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(data[:pos]) } -// Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { - // User password - scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { + pktLen := 4 + len(authData) + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } - // Add the scrambled password [null terminated string] - copy(data[4:], scrambleBuff) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Client clear text authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { - // Calculate the packet length and add a tailing 0 - pktLen := len(mc.cfg.passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn - } - - // Add the clear password [null terminated string] - copy(data[4:], mc.cfg.passwd) - data[4+pktLen-1] = 0x00 - + // Add the auth data [EOF] + copy(data[4:], authData) return mc.writePacket(data) } @@ -362,11 +421,11 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // Add command byte @@ -381,11 +440,11 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // Add command byte @@ -402,11 +461,11 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // Add command byte @@ -426,37 +485,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { * Result Packets * ******************************************************************************/ +func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { + data, err := mc.readPacket() + if err != nil { + return nil, "", err + } + + // packet indicator + switch data[0] { + + case iOK: + return nil, "", mc.handleOkPacket(data) + + case iAuthMoreData: + return data[1:], "", err + + case iEOF: + if len(data) == 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, "mysql_old_password", nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", ErrMalformPkt + } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", mc.handleErrorPacket(data) + } +} + // Returns error if Packet is not an 'Result OK'-Packet func (mc *mysqlConn) readResultOK() error { data, err := mc.readPacket() - if err == nil { - // packet indicator - switch data[0] { - - case iOK: - return mc.handleOkPacket(data) - - case iEOF: - if len(data) > 1 { - plugin := string(data[1:bytes.IndexByte(data, 0x00)]) - if plugin == "mysql_old_password" { - // using old_passwords - return ErrOldPassword - } else if plugin == "mysql_clear_password" { - // using clear text password - return ErrCleartextPassword - } else { - return ErrUnknownPlugin - } - } else { - return ErrOldPassword - } - - default: // Error otherwise - return mc.handleErrorPacket(data) - } + if err != nil { + return err } - return err + + if data[0] == iOK { + return mc.handleOkPacket(data) + } + return mc.handleErrorPacket(data) } // Result Set Header Packet @@ -499,6 +571,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // Error Number [16 bit uint] errno := binary.LittleEndian.Uint16(data[1:3]) + // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { + // Oops; we are connected to a read-only connection, and won't be able + // to issue any write statements. Since RejectReadOnly is configured, + // we throw away this connection hoping this one would have write + // permission. This is specifically for a possible race condition + // during failover (e.g. on AWS Aurora). See README.md for more. + // + // We explicitly close the connection before returning + // driver.ErrBadConn to ensure that `database/sql` purges this + // connection and initiates a new one for next statement next time. + mc.Close() + return driver.ErrBadConn + } + pos := 3 // SQL State [optional: # + 5bytes string] @@ -514,6 +602,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { } } +func readStatus(b []byte) statusFlag { + return statusFlag(b[0]) | statusFlag(b[1])<<8 +} + // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { @@ -528,18 +620,14 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] - mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8 - - // warning count [2 bytes] - if !mc.strict { - return nil - } else { - pos := 1 + n + m + 2 - if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - return mc.getWarnings() - } + mc.status = readStatus(data[1+n+m : 1+n+m+2]) + if mc.status&statusMoreResultsExists != 0 { return nil } + + // warning count [2 bytes] + + return nil } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -558,7 +646,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if i == count { return columns, nil } - return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) + return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) } // Catalog @@ -575,7 +663,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { pos += n // Table [len coded string] - if mc.cfg.columnsWithAlias { + if mc.cfg.ColumnsWithAlias { tableName, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err @@ -610,14 +698,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } + pos += n // Filler [uint8] + pos++ + // Charset [charset, collation uint8] + columns[i].charSet = data[pos] + pos += 2 + // Length [uint32] - pos += n + 1 + 2 + 4 + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 // Field type [uint8] - columns[i].fieldType = data[pos] + columns[i].fieldType = fieldType(data[pos]) pos++ // Flags [uint16] @@ -640,6 +735,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc + if rows.rs.done { + return io.EOF + } + data, err := mc.readPacket() if err != nil { return err @@ -647,7 +746,12 @@ func (rows *textRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { - rows.mc = nil + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil + } return io.EOF } if data[0] == iERR { @@ -669,12 +773,12 @@ func (rows *textRows) readRow(dest []driver.Value) error { if !mc.parseTime { continue } else { - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( string(dest[i].([]byte)), - mc.cfg.loc, + mc.cfg.Loc, ) if err == nil { continue @@ -699,12 +803,19 @@ func (rows *textRows) readRow(dest []driver.Value) error { func (mc *mysqlConn) readUntilEOF() error { for { data, err := mc.readPacket() - - // No Err and no EOF Packet - if err == nil && data[0] != iEOF { - continue + if err != nil { + return err + } + + switch data[0] { + case iERR: + return mc.handleErrorPacket(data) + case iEOF: + if len(data) == 5 { + mc.status = readStatus(data[3:]) + } + return nil } - return err // Err or EOF } } @@ -734,22 +845,15 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // Reserved [8 bit] // Warning count [16 bit uint] - if !stmt.mc.strict { - return columnCount, nil - } else { - // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { - return columnCount, stmt.mc.getWarnings() - } - return columnCount, nil - } + + return columnCount, nil } return 0, err } // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { - maxLen := stmt.mc.maxPacketAllowed - 1 + maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen // After the header (bytes 0-3) follows before the data: @@ -758,7 +862,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // 2 bytes paramID const dataOffset = 1 + 4 + 2 - // Can not use the write buffer since + // Cannot use the write buffer since // a) the buffer is too small // b) it is in use data := make([]byte, 4+1+4+2+len(arg)) @@ -804,7 +908,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( - "Arguments count mismatch (Got: %d Has: %d)", + "argument count mismatch (got: %d; has: %d)", len(args), stmt.paramCount, ) @@ -813,20 +917,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc + // Determine threshold dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + // Reset packet-sequence mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. } - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // command [1 byte] @@ -851,7 +963,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos := minPktLen var nullMask []byte - if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here @@ -860,10 +972,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. pos += maskLen } else { nullMask = data[pos : pos+maskLen] - for i := 0; i < maskLen; i++ { + for i := range nullMask { nullMask[i] = 0 } pos += maskLen @@ -885,7 +998,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // build NULL-bitmap if arg == nil { nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 continue } @@ -893,7 +1006,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // cache types and values switch v := arg.(type) { case int64: - paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -908,8 +1021,24 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) } + case uint64: + paramTypes[i+i] = byte(fieldTypeLongLong) + paramTypes[i+i+1] = 0x80 // type is unsigned + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + case float64: - paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -925,7 +1054,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case bool: - paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i+1] = 0x00 if v { @@ -937,10 +1066,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case []byte: // Common case (non-nil value) first if v != nil { - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -955,14 +1084,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Handle []byte(nil) as a NULL value nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 case string: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -974,23 +1103,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case time.Time: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - var val []byte + var a [64]byte + var b = a[:0] + if v.IsZero() { - val = []byte("0000-00-00") + b = append(b, "0000-00-00"...) } else { - val = []byte(v.In(mc.cfg.loc).Format(timeFormat)) + b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) } paramValues = appendLengthEncodedInteger(paramValues, - uint64(len(val)), + uint64(len(b)), ) - paramValues = append(paramValues, val...) + paramValues = append(paramValues, b...) default: - return fmt.Errorf("Can't convert type: %T", arg) + return fmt.Errorf("cannot convert type: %T", arg) } } @@ -998,7 +1129,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } } pos += len(paramValues) @@ -1008,6 +1142,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { return mc.writePacket(data) } +func (mc *mysqlConn) discardResults() error { + for mc.status&statusMoreResultsExists != 0 { + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } + } + return nil +} + // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *binaryRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() @@ -1017,14 +1171,20 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - rows.mc = nil // EOF Packet if data[0] == iEOF && len(data) == 5 { + rows.mc.status = readStatus(data[3:]) + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil + } return io.EOF } + mc := rows.mc + rows.mc = nil // Error otherwise - return rows.mc.handleErrorPacket(data) + return mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] @@ -1040,14 +1200,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } // Convert to byte-coded string - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1056,7 +1216,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeShort, fieldTypeYear: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1065,7 +1225,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeInt24, fieldTypeLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1074,7 +1234,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeLongLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1088,7 +1248,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeFloat: - dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 continue @@ -1101,7 +1261,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, - fieldTypeVarString, fieldTypeString, fieldTypeGeometry: + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: var isNull bool var n int dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) @@ -1128,41 +1288,41 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case isNull: dest[i] = nil continue - case rows.columns[i].fieldType == fieldTypeTime: + case rows.rs.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: dstlen = 8 + 1 + decimals default: return fmt.Errorf( - "MySQL protocol error, illegal decimals value %d", - rows.columns[i].decimals, + "protocol error, illegal decimals value %d", + rows.rs.columns[i].decimals, ) } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) case rows.mc.parseTime: - dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 - if rows.columns[i].fieldType == fieldTypeDate { + if rows.rs.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: dstlen = 19 + 1 + decimals default: return fmt.Errorf( - "MySQL protocol error, illegal decimals value %d", - rows.columns[i].decimals, + "protocol error, illegal decimals value %d", + rows.rs.columns[i].decimals, ) } } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) } if err == nil { @@ -1174,7 +1334,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) } } diff --git a/vendor/github.com/go-sql-driver/mysql/packets_test.go b/vendor/github.com/go-sql-driver/mysql/packets_test.go new file mode 100644 index 0000000..b61e4db --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/packets_test.go @@ -0,0 +1,336 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "errors" + "net" + "testing" + "time" +) + +var ( + errConnClosed = errors.New("connection is closed") + errConnTooManyReads = errors.New("too many reads") + errConnTooManyWrites = errors.New("too many writes") +) + +// struct to mock a net.Conn for testing purposes +type mockConn struct { + laddr net.Addr + raddr net.Addr + data []byte + written []byte + queuedReplies [][]byte + closed bool + read int + reads int + writes int + maxReads int + maxWrites int +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.reads++ + if m.maxReads > 0 && m.reads > m.maxReads { + return 0, errConnTooManyReads + } + + n = copy(b, m.data) + m.read += n + m.data = m.data[n:] + return +} +func (m *mockConn) Write(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.writes++ + if m.maxWrites > 0 && m.writes > m.maxWrites { + return 0, errConnTooManyWrites + } + + n = len(b) + m.written = append(m.written, b...) + + if n > 0 && len(m.queuedReplies) > 0 { + m.data = m.queuedReplies[0] + m.queuedReplies = m.queuedReplies[1:] + } + return +} +func (m *mockConn) Close() error { + m.closed = true + return nil +} +func (m *mockConn) LocalAddr() net.Addr { + return m.laddr +} +func (m *mockConn) RemoteAddr() net.Addr { + return m.raddr +} +func (m *mockConn) SetDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetReadDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// make sure mockConn implements the net.Conn interface +var _ net.Conn = new(mockConn) + +func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: NewConfig(), + netConn: conn, + closech: make(chan struct{}), + maxAllowedPacket: defaultMaxAllowedPacket, + sequence: sequence, + } + return conn, mc +} + +func TestReadPacketSingleByte(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 1 { + t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet)) + } + if packet[0] != 0xff { + t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0]) + } +} + +func TestReadPacketWrongSequenceID(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + // too low sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + mc.sequence = 1 + _, err := mc.readPacket() + if err != ErrPktSync { + t.Errorf("expected ErrPktSync, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // too high sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} + _, err = mc.readPacket() + if err != ErrPktSyncMul { + t.Errorf("expected ErrPktSyncMul, got %v", err) + } +} + +func TestReadPacketSplit(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + data := make([]byte, maxPacketSize*2+4*3) + const pkt2ofs = maxPacketSize + 4 + const pkt3ofs = 2 * (maxPacketSize + 4) + + // case 1: payload has length maxPacketSize + data = data[:pkt2ofs+4] + + // 1st packet has maxPacketSize length and sequence id 0 + // ff ff ff 00 ... + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + + // mark the payload start and end of 1st packet so that we can check if the + // content was correctly appended + data[4] = 0x11 + data[maxPacketSize+3] = 0x22 + + // 2nd packet has payload length 0 and squence id 1 + // 00 00 00 01 + data[pkt2ofs+3] = 0x01 + + conn.data = data + conn.maxReads = 3 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize-1] != 0x22 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1]) + } + + // case 2: payload has length which is a multiple of maxPacketSize + data = data[:cap(data)] + + // 2nd packet now has maxPacketSize length + data[pkt2ofs] = 0xff + data[pkt2ofs+1] = 0xff + data[pkt2ofs+2] = 0xff + + // mark the payload start and end of the 2nd packet + data[pkt2ofs+4] = 0x33 + data[pkt2ofs+maxPacketSize+3] = 0x44 + + // 3rd packet has payload length 0 and squence id 2 + // 00 00 00 02 + data[pkt3ofs+3] = 0x02 + + conn.data = data + conn.reads = 0 + conn.maxReads = 5 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 2*maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[2*maxPacketSize-1] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1]) + } + + // case 3: payload has a length larger maxPacketSize, which is not an exact + // multiple of it + data = data[:pkt2ofs+4+42] + data[pkt2ofs] = 0x2a + data[pkt2ofs+1] = 0x00 + data[pkt2ofs+2] = 0x00 + data[pkt2ofs+4+41] = 0x44 + + conn.data = data + conn.reads = 0 + conn.maxReads = 4 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize+42 { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize+41] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41]) + } +} + +func TestReadPacketFail(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + closech: make(chan struct{}), + } + + // illegal empty (stand-alone) packet + conn.data = []byte{0x00, 0x00, 0x00, 0x00} + conn.maxReads = 1 + _, err := mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read header + conn.closed = true + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.closed = false + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read body + conn.maxReads = 1 + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } +} + +// https://github.com/go-sql-driver/mysql/pull/801 +// not-NUL terminated plugin_name in init packet +func TestRegression801(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: new(Config), + sequence: 42, + closech: make(chan struct{}), + } + + conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, + 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77, + 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, + 112, 97, 115, 115, 119, 111, 114, 100} + conn.maxReads = 1 + + authData, pluginName, err := mc.readHandshakePacket() + if err != nil { + t.Fatalf("got error: %v", err) + } + + if pluginName != "mysql_native_password" { + t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) + } + + expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114, + 47, 85, 75, 109, 99, 51, 77, 50, 64} + if !bytes.Equal(authData, expectedAuthData) { + t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go index ba606e1..888bdb5 100644 --- a/vendor/github.com/go-sql-driver/mysql/rows.go +++ b/vendor/github.com/go-sql-driver/mysql/rows.go @@ -11,19 +11,20 @@ package mysql import ( "database/sql/driver" "io" + "math" + "reflect" ) -type mysqlField struct { - tableName string - name string - flags fieldFlag - fieldType byte - decimals byte +type resultSet struct { + columns []mysqlField + columnNames []string + done bool } type mysqlRows struct { - mc *mysqlConn - columns []mysqlField + mc *mysqlConn + rs resultSet + finish func() } type binaryRows struct { @@ -34,45 +35,163 @@ type textRows struct { mysqlRows } -type emptyRows struct{} - func (rows *mysqlRows) Columns() []string { - columns := make([]string, len(rows.columns)) - if rows.mc.cfg.columnsWithAlias { + if rows.rs.columnNames != nil { + return rows.rs.columnNames + } + + columns := make([]string, len(rows.rs.columns)) + if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { for i := range columns { - if tableName := rows.columns[i].tableName; len(tableName) > 0 { - columns[i] = tableName + "." + rows.columns[i].name + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name } else { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } } else { for i := range columns { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } + + rows.rs.columnNames = columns return columns } -func (rows *mysqlRows) Close() error { +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + return rows.rs.columns[i].typeDatabaseName() +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + +func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + mc := rows.mc if mc == nil { return nil } - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err + } + + // flip the buffer for this connection if we need to drain it. + // note that for a successful query (i.e. one where rows.next() + // has been called until it returns false), `rows.mc` will be nil + // by the time the user calls `(*Rows).Close`, so we won't reach this + // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 + mc.buf.flip() + + // Remove unread packets from stream + if !rows.rs.done { + err = mc.readUntilEOF() + } + if err == nil { + if err = mc.discardResults(); err != nil { + return err + } + } + + rows.mc = nil + return err +} + +func (rows *mysqlRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *mysqlRows) nextResultSet() (int, error) { + if rows.mc == nil { + return 0, io.EOF + } + if err := rows.mc.error(); err != nil { + return 0, err } // Remove unread packets from stream - err := mc.readUntilEOF() - rows.mc = nil + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return 0, err + } + rows.rs.done = true + } + + if !rows.HasNextResultSet() { + rows.mc = nil + return 0, io.EOF + } + rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() +} + +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } + + rows.rs.done = true + } +} + +func (rows *binaryRows) NextResultSet() error { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) return err } func (rows *binaryRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream @@ -81,10 +200,20 @@ func (rows *binaryRows) Next(dest []driver.Value) error { return io.EOF } +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + func (rows *textRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream @@ -92,15 +221,3 @@ func (rows *textRows) Next(dest []driver.Value) error { } return io.EOF } - -func (rows emptyRows) Columns() []string { - return nil -} - -func (rows emptyRows) Close() error { - return nil -} - -func (rows emptyRows) Next(dest []driver.Value) error { - return io.EOF -} diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index 6e869b3..f7e3709 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -11,20 +11,22 @@ package mysql import ( "database/sql/driver" "fmt" + "io" "reflect" - "strconv" ) type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int - columns []mysqlField // cached from the first query } func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.netConn == nil { - errLog.Print(ErrInvalidConn) + if stmt.mc == nil || stmt.mc.closed.IsSet() { + // driver.Stmt.Close can be called more than once, thus this function + // has to be idempotent. + // See also Issue #450 and golang/go#16019. + //errLog.Print(ErrInvalidConn) return driver.ErrBadConn } @@ -42,14 +44,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.netConn == nil { + if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc @@ -59,37 +61,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil { - if resLen > 0 { - // Columns - err = mc.readUntilEOF() - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - // Rows - err = mc.readUntilEOF() + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err } - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err } } - return nil, err + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.mc.netConn == nil { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { + if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc @@ -101,17 +111,18 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } rows := new(binaryRows) - rows.mc = mc if resLen > 0 { - // Columns - // If not cached, read them and cache them - if stmt.columns == nil { - rows.columns, err = mc.readColumns(resLen) - stmt.columns = rows.columns - } else { - rows.columns = stmt.columns - err = mc.readUntilEOF() + rows.mc = mc + rows.rs.columns, err = mc.readColumns(resLen) + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err } } @@ -120,31 +131,74 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { type converter struct{} +// ConvertValue mirrors the reference/default converter in database/sql/driver +// with _one_ exception. We support uint64 with their high bit and the default +// implementation does not. This function should be kept in sync with +// database/sql/driver defaultConverter.ConvertValue() except for that +// deliberate difference. func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if driver.IsValue(v) { return v, nil } + if vr, ok := v.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + } + rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil + } else { + return c.ConvertValue(rv.Elem().Interface()) } - return c.ConvertValue(rv.Elem().Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rv.Int(), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return int64(rv.Uint()), nil - case reflect.Uint64: - u64 := rv.Uint() - if u64 >= 1<<63 { - return strconv.FormatUint(u64, 10), nil - } - return int64(u64), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return rv.Uint(), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) + case reflect.String: + return rv.String(), nil } return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This is an exact copy of the same-named unexported function from the +// database/sql package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/vendor/github.com/go-sql-driver/mysql/statement_test.go b/vendor/github.com/go-sql-driver/mysql/statement_test.go new file mode 100644 index 0000000..4b9914f --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/statement_test.go @@ -0,0 +1,126 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "testing" +) + +func TestConvertDerivedString(t *testing.T) { + type derived string + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Derived string type not convertible", err) + } + + if output != "value" { + t.Fatalf("Derived string type not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedByteSlice(t *testing.T) { + type derived []uint8 + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Byte slice not convertible", err) + } + + if bytes.Compare(output.([]byte), []byte("value")) != 0 { + t.Fatalf("Byte slice not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedUnsupportedSlice(t *testing.T) { + type derived []int + + _, err := converter{}.ConvertValue(derived{1}) + if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { + t.Fatal("Unexpected error", err) + } +} + +func TestConvertDerivedBool(t *testing.T) { + type derived bool + + output, err := converter{}.ConvertValue(derived(true)) + if err != nil { + t.Fatal("Derived bool type not convertible", err) + } + + if output != true { + t.Fatalf("Derived bool type not converted, got %#v %T", output, output) + } +} + +func TestConvertPointer(t *testing.T) { + str := "value" + + output, err := converter{}.ConvertValue(&str) + if err != nil { + t.Fatal("Pointer type not convertible", err) + } + + if output != "value" { + t.Fatalf("Pointer type not converted, got %#v %T", output, output) + } +} + +func TestConvertSignedIntegers(t *testing.T) { + values := []interface{}{ + int8(-42), + int16(-42), + int32(-42), + int64(-42), + int(-42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(-42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } +} + +func TestConvertUnsignedIntegers(t *testing.T) { + values := []interface{}{ + uint8(42), + uint16(42), + uint32(42), + uint64(42), + uint(42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != uint64(42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } + + output, err := converter{}.ConvertValue(^uint64(0)) + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if output != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/transaction.go b/vendor/github.com/go-sql-driver/mysql/transaction.go index 33c749b..417d727 100644 --- a/vendor/github.com/go-sql-driver/mysql/transaction.go +++ b/vendor/github.com/go-sql-driver/mysql/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") diff --git a/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/go-sql-driver/mysql/utils.go index 6a26ad1..9552e80 100644 --- a/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/go-sql-driver/mysql/utils.go @@ -9,35 +9,32 @@ package mysql import ( - "crypto/sha1" "crypto/tls" + "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" - "net" - "net/url" + "strconv" "strings" + "sync" + "sync/atomic" "time" ) +// Registry for custom tls.Configs var ( - tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs - - errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") - errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") - errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") - errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset") + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config ) -func init() { - tlsConfigRegister = make(map[string]*tls.Config) -} - // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // +// Note: The provided tls.Config is exclusively owned by the driver after +// registering it. +// // rootCertPool := x509.NewCertPool() // pem, err := ioutil.ReadFile("/path/ca-cert.pem") // if err != nil { @@ -59,243 +56,35 @@ func init() { // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { - if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { - return fmt.Errorf("Key '%s' is reserved", key) + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { + return fmt.Errorf("key '%s' is reserved", key) } - tlsConfigRegister[key] = config + tlsConfigLock.Lock() + if tlsConfigRegistry == nil { + tlsConfigRegistry = make(map[string]*tls.Config) + } + + tlsConfigRegistry[key] = config + tlsConfigLock.Unlock() return nil } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { - delete(tlsConfigRegister, key) + tlsConfigLock.Lock() + if tlsConfigRegistry != nil { + delete(tlsConfigRegistry, key) + } + tlsConfigLock.Unlock() } -// parseDSN parses the DSN string to a config -func parseDSN(dsn string) (cfg *config, err error) { - // New config with some default values - cfg = &config{ - loc: time.UTC, - collation: defaultCollation, +func getTLSConfigClone(key string) (config *tls.Config) { + tlsConfigLock.RLock() + if v, ok := tlsConfigRegistry[key]; ok { + config = v.Clone() } - - // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] - // Find the last '/' (since the password or the net addr might contain a '/') - foundSlash := false - for i := len(dsn) - 1; i >= 0; i-- { - if dsn[i] == '/' { - foundSlash = true - var j, k int - - // left part is empty if i <= 0 - if i > 0 { - // [username[:password]@][protocol[(address)]] - // Find the last '@' in dsn[:i] - for j = i; j >= 0; j-- { - if dsn[j] == '@' { - // username[:password] - // Find the first ':' in dsn[:j] - for k = 0; k < j; k++ { - if dsn[k] == ':' { - cfg.passwd = dsn[k+1 : j] - break - } - } - cfg.user = dsn[:k] - - break - } - } - - // [protocol[(address)]] - // Find the first '(' in dsn[j+1:i] - for k = j + 1; k < i; k++ { - if dsn[k] == '(' { - // dsn[i-1] must be == ')' if an address is specified - if dsn[i-1] != ')' { - if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped - } - return nil, errInvalidDSNAddr - } - cfg.addr = dsn[k+1 : i-1] - break - } - } - cfg.net = dsn[j+1 : k] - } - - // dbname[?param1=value1&...¶mN=valueN] - // Find the first '?' in dsn[i+1:] - for j = i + 1; j < len(dsn); j++ { - if dsn[j] == '?' { - if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { - return - } - break - } - } - cfg.dbname = dsn[i+1 : j] - - break - } - } - - if !foundSlash && len(dsn) > 0 { - return nil, errInvalidDSNNoSlash - } - - if cfg.interpolateParams && unsafeCollations[cfg.collation] { - return nil, errInvalidDSNUnsafeCollation - } - - // Set default network if empty - if cfg.net == "" { - cfg.net = "tcp" - } - - // Set default address if empty - if cfg.addr == "" { - switch cfg.net { - case "tcp": - cfg.addr = "127.0.0.1:3306" - case "unix": - cfg.addr = "/tmp/mysql.sock" - default: - return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") - } - - } - - return -} - -// parseDSNParams parses the DSN "query string" -// Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { - for _, v := range strings.Split(params, "&") { - param := strings.SplitN(v, "=", 2) - if len(param) != 2 { - continue - } - - // cfg params - switch value := param[1]; param[0] { - - // Enable client side placeholder substitution - case "interpolateParams": - var isBool bool - cfg.interpolateParams, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Disable INFILE whitelist / enable all files - case "allowAllFiles": - var isBool bool - cfg.allowAllFiles, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Use cleartext authentication mode (MySQL 5.5.10+) - case "allowCleartextPasswords": - var isBool bool - cfg.allowCleartextPasswords, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Use old authentication mode (pre MySQL 4.1) - case "allowOldPasswords": - var isBool bool - cfg.allowOldPasswords, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Switch "rowsAffected" mode - case "clientFoundRows": - var isBool bool - cfg.clientFoundRows, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Collation - case "collation": - collation, ok := collations[value] - if !ok { - // Note possibility for false negatives: - // could be triggered although the collation is valid if the - // collations map does not contain entries the server supports. - err = errors.New("unknown collation") - return - } - cfg.collation = collation - break - - case "columnsWithAlias": - var isBool bool - cfg.columnsWithAlias, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Time Location - case "loc": - if value, err = url.QueryUnescape(value); err != nil { - return - } - cfg.loc, err = time.LoadLocation(value) - if err != nil { - return - } - - // Dial Timeout - case "timeout": - cfg.timeout, err = time.ParseDuration(value) - if err != nil { - return - } - - // TLS-Encryption - case "tls": - boolValue, isBool := readBool(value) - if isBool { - if boolValue { - cfg.tls = &tls.Config{} - } - } else { - if strings.ToLower(value) == "skip-verify" { - cfg.tls = &tls.Config{InsecureSkipVerify: true} - } else if tlsConfig, ok := tlsConfigRegister[value]; ok { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.addr) - if err == nil { - tlsConfig.ServerName = host - } - } - - cfg.tls = tlsConfig - } else { - return fmt.Errorf("Invalid value / unknown config name: %s", value) - } - } - - default: - // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) - } - - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { - return - } - } - } - + tlsConfigLock.RUnlock() return } @@ -313,177 +102,10 @@ func readBool(input string) (value bool, valid bool) { return } -/****************************************************************************** -* Authentication * -******************************************************************************/ - -// Encrypt password using 4.1+ method -func scramblePassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - // stage1Hash = SHA1(password) - crypt := sha1.New() - crypt.Write(password) - stage1 := crypt.Sum(nil) - - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash - crypt.Reset() - crypt.Write(stage1) - hash := crypt.Sum(nil) - - // outer Hash - crypt.Reset() - crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) - - // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] - } - return scramble -} - -// Encrypt password using pre 4.1 (old password) method -// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c -type myRnd struct { - seed1, seed2 uint32 -} - -const myRndMaxVal = 0x3FFFFFFF - -// Pseudo random number generator -func newMyRnd(seed1, seed2 uint32) *myRnd { - return &myRnd{ - seed1: seed1 % myRndMaxVal, - seed2: seed2 % myRndMaxVal, - } -} - -// Tested to be equivalent to MariaDB's floating point variant -// http://play.golang.org/p/QHvhd4qved -// http://play.golang.org/p/RG0q4ElWDx -func (r *myRnd) NextByte() byte { - r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal - r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal - - return byte(uint64(r.seed1) * 31 / myRndMaxVal) -} - -// Generate binary hash from byte string using insecure pre 4.1 method -func pwHash(password []byte) (result [2]uint32) { - var add uint32 = 7 - var tmp uint32 - - result[0] = 1345345333 - result[1] = 0x12345671 - - for _, c := range password { - // skip spaces and tabs in password - if c == ' ' || c == '\t' { - continue - } - - tmp = uint32(c) - result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) - result[1] += (result[1] << 8) ^ result[0] - add += tmp - } - - // Remove sign bit (1<<31)-1) - result[0] &= 0x7FFFFFFF - result[1] &= 0x7FFFFFFF - - return -} - -// Encrypt password using insecure pre 4.1 method -func scrambleOldPassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - scramble = scramble[:8] - - hashPw := pwHash(password) - hashSc := pwHash(scramble) - - r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) - - var out [8]byte - for i := range out { - out[i] = r.NextByte() + 64 - } - - mask := r.NextByte() - for i := range out { - out[i] ^= mask - } - - return out[:] -} - /****************************************************************************** * Time related utils * ******************************************************************************/ -// NullTime represents a time.Time that may be NULL. -// NullTime implements the Scanner interface so -// it can be used as a scan destination: -// -// var nt NullTime -// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) -// ... -// if nt.Valid { -// // use nt.Time -// } else { -// // NULL value -// } -// -// This NullTime implementation is not driver-specific -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -// Scan implements the Scanner interface. -// The value type must be time.Time or string / []byte (formatted time-string), -// otherwise Scan fails. -func (nt *NullTime) Scan(value interface{}) (err error) { - if value == nil { - nt.Time, nt.Valid = time.Time{}, false - return - } - - switch v := value.(type) { - case time.Time: - nt.Time, nt.Valid = v, true - return - case []byte: - nt.Time, err = parseDateTime(string(v), time.UTC) - nt.Valid = (err == nil) - return - case string: - nt.Time, err = parseDateTime(v, time.UTC) - nt.Valid = (err == nil) - return - } - - nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) -} - -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} - func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { base := "0000-00-00 00:00:00.0000000" switch len(str) { @@ -493,7 +115,7 @@ func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { } t, err = time.Parse(timeFormat[:len(str)], str) default: - err = fmt.Errorf("Invalid Time-String: %s", str) + err = fmt.Errorf("invalid time string: %s", str) return } @@ -542,7 +164,7 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va loc, ), nil } - return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) + return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } // zeroDateTime is used in formatBinaryDateTime to avoid an allocation @@ -554,87 +176,104 @@ var zeroDateTime = []byte("0000-00-00 00:00:00.000000") const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" -func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { +func appendMicrosecs(dst, src []byte, decimals int) []byte { + if decimals <= 0 { + return dst + } + if len(src) == 0 { + return append(dst, ".000000"[:decimals+1]...) + } + + microsecs := binary.LittleEndian.Uint32(src[:4]) + p1 := byte(microsecs / 10000) + microsecs -= 10000 * uint32(p1) + p2 := byte(microsecs / 100) + microsecs -= 100 * uint32(p2) + p3 := byte(microsecs) + + switch decimals { + default: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], digits01[p3], + ) + case 1: + return append(dst, '.', + digits10[p1], + ) + case 2: + return append(dst, '.', + digits10[p1], digits01[p1], + ) + case 3: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], + ) + case 4: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + ) + case 5: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], + ) + } +} + +func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { // length expects the deterministic length of the zero value, // negative time and 100+ hours are automatically added if needed if len(src) == 0 { - if justTime { - return zeroDateTime[11 : 11+length], nil - } return zeroDateTime[:length], nil } - var dst []byte // return value - var pt, p1, p2, p3 byte // current digit pair - var zOffs byte // offset of value in zeroDateTime - if justTime { - switch length { - case - 8, // time (can be up to 10 when negative and 100+ hours) - 10, 11, 12, 13, 14, 15: // time with fractional seconds - default: - return nil, fmt.Errorf("illegal TIME length %d", length) + var dst []byte // return value + var p1, p2, p3 byte // current digit pair + + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" } - switch len(src) { - case 8, 12: - default: - return nil, fmt.Errorf("Invalid TIME-packet length %d", len(src)) - } - // +2 to enable negative time and 100+ hours - dst = make([]byte, 0, length+2) - if src[0] == 1 { - dst = append(dst, '-') - } - if src[1] != 0 { - hour := uint16(src[1])*24 + uint16(src[5]) - pt = byte(hour / 100) - p1 = byte(hour - 100*uint16(pt)) - dst = append(dst, digits01[pt]) - } else { - p1 = src[5] - } - zOffs = 11 - src = src[6:] - } else { - switch length { - case 10, 19, 21, 22, 23, 24, 25, 26: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s length %d", t, length) - } - switch len(src) { - case 4, 7, 11: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s-packet length %d", t, len(src)) - } - dst = make([]byte, 0, length) - // start with the date - year := binary.LittleEndian.Uint16(src[:2]) - pt = byte(year / 100) - p1 = byte(year - 100*uint16(pt)) - p2, p3 = src[2], src[3] - dst = append(dst, - digits10[pt], digits01[pt], - digits10[p1], digits01[p1], '-', - digits10[p2], digits01[p2], '-', - digits10[p3], digits01[p3], - ) - if length == 10 { - return dst, nil - } - if len(src) == 4 { - return append(dst, zeroDateTime[10:length]...), nil - } - dst = append(dst, ' ') - p1 = src[4] // hour - src = src[5:] + return nil, fmt.Errorf("illegal %s length %d", t, length) } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt := year / 100 + p1 = byte(year - 100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + // p1 is 2-digit hour, src is after hour p2, p3 = src[0], src[1] dst = append(dst, @@ -642,51 +281,49 @@ func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value digits10[p2], digits01[p2], ':', digits10[p3], digits01[p3], ) - if length <= byte(len(dst)) { - return dst, nil - } - src = src[2:] + return appendMicrosecs(dst, src[2:], int(length)-20), nil +} + +func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed if len(src) == 0 { - return append(dst, zeroDateTime[19:zOffs+length]...), nil + return zeroDateTime[11 : 11+length], nil } - microsecs := binary.LittleEndian.Uint32(src[:4]) - p1 = byte(microsecs / 10000) - microsecs -= 10000 * uint32(p1) - p2 = byte(microsecs / 100) - microsecs -= 100 * uint32(p2) - p3 = byte(microsecs) - switch decimals := zOffs + length - 20; decimals { + var dst []byte // return value + + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds default: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - digits10[p3], digits01[p3], - ), nil - case 1: - return append(dst, '.', - digits10[p1], - ), nil - case 2: - return append(dst, '.', - digits10[p1], digits01[p1], - ), nil - case 3: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], - ), nil - case 4: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - ), nil - case 5: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - digits10[p3], - ), nil + return nil, fmt.Errorf("illegal TIME length %d", length) } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) + } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + days := binary.LittleEndian.Uint32(src[1:5]) + hours := int64(days)*24 + int64(src[5]) + + if hours >= 100 { + dst = strconv.AppendInt(dst, hours, 10) + } else { + dst = append(dst, digits10[hours], digits01[hours]) + } + + min, sec := src[6], src[7] + dst = append(dst, ':', + digits10[min], digits01[min], ':', + digits10[sec], digits01[sec], + ) + return appendMicrosecs(dst, src[8:], int(length)-9), nil } /****************************************************************************** @@ -752,7 +389,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { // Check data length if len(b) >= n { - return b[n-int(num) : n], false, n, nil + return b[n-int(num) : n : n], false, n, nil } return nil, false, n, io.EOF } @@ -781,8 +418,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - switch b[0] { + switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 @@ -877,7 +514,7 @@ func escapeBytesBackslash(buf, v []byte) []byte { pos += 2 default: buf[pos] = c - pos += 1 + pos++ } } @@ -922,7 +559,7 @@ func escapeStringBackslash(buf []byte, v string) []byte { pos += 2 default: buf[pos] = c - pos += 1 + pos++ } } @@ -971,3 +608,94 @@ func escapeStringQuotes(buf []byte, v string) []byte { return buf[:pos] } + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} + +// atomicBool is a wrapper around uint32 for usage as a boolean value with +// atomic access. +type atomicBool struct { + _noCopy noCopy + value uint32 +} + +// IsSet returns whether the current boolean value is true +func (ab *atomicBool) IsSet() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Set sets the value of the bool regardless of the previous value +func (ab *atomicBool) Set(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// TrySet sets the value of the bool and returns whether the value changed +func (ab *atomicBool) TrySet(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) == 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} + +// atomicError is a wrapper for atomically accessed error values +type atomicError struct { + _noCopy noCopy + value atomic.Value +} + +// Set sets the error value regardless of the previous value. +// The value must not be nil +func (ae *atomicError) Set(value error) { + ae.value.Store(value) +} + +// Value returns the current error value +func (ae *atomicError) Value() error { + if v := ae.value.Load(); v != nil { + // this will panic if the value doesn't implement the error interface + return v.(error) + } + return nil +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_test.go b/vendor/github.com/go-sql-driver/mysql/utils_test.go new file mode 100644 index 0000000..10a60c2 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/utils_test.go @@ -0,0 +1,293 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/binary" + "testing" +) + +func TestLengthEncodedInteger(t *testing.T) { + var integerTests = []struct { + num uint64 + encoded []byte + }{ + {0x0000000000000000, []byte{0x00}}, + {0x0000000000000012, []byte{0x12}}, + {0x00000000000000fa, []byte{0xfa}}, + {0x0000000000000100, []byte{0xfc, 0x00, 0x01}}, + {0x0000000000001234, []byte{0xfc, 0x34, 0x12}}, + {0x000000000000ffff, []byte{0xfc, 0xff, 0xff}}, + {0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}}, + {0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}}, + {0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}}, + {0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}}, + {0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}}, + {0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + } + + for _, tst := range integerTests { + num, isNull, numLen := readLengthEncodedInteger(tst.encoded) + if isNull { + t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num) + } + if num != tst.num { + t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num) + } + if numLen != len(tst.encoded) { + t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen) + } + encoded := appendLengthEncodedInteger(nil, num) + if !bytes.Equal(encoded, tst.encoded) { + t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded) + } + } +} + +func TestFormatBinaryDateTime(t *testing.T) { + rawDate := [11]byte{} + binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years + rawDate[2] = 12 // months + rawDate[3] = 30 // days + rawDate[4] = 15 // hours + rawDate[5] = 46 // minutes + rawDate[6] = 23 // seconds + binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds + expect := func(expected string, inlen, outlen uint8) { + actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for length in %d, out %d", + expected, actual, inlen, outlen, + ) + } + } + expect("0000-00-00", 0, 10) + expect("0000-00-00 00:00:00", 0, 19) + expect("1978-12-30", 4, 10) + expect("1978-12-30 15:46:23", 7, 19) + expect("1978-12-30 15:46:23.987654", 11, 26) +} + +func TestFormatBinaryTime(t *testing.T) { + expect := func(expected string, src []byte, outlen uint8) { + actual, _ := formatBinaryTime(src, outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for src=%q and outlen=%d", + expected, actual, src, outlen) + } + } + + // binary format: + // sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4) + + // Zeros + expect("00:00:00", []byte{}, 8) + expect("00:00:00.0", []byte{}, 10) + expect("00:00:00.000000", []byte{}, 15) + + // Without micro(4) + expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11) + expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8) + expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8) + expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8) + + // With micro(4) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11) + expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15) +} + +func TestEscapeBackslash(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesBackslash([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringBackslash([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\\0bar", "foo\x00bar") + expect("foo\\nbar", "foo\nbar") + expect("foo\\rbar", "foo\rbar") + expect("foo\\Zbar", "foo\x1abar") + expect("foo\\\"bar", "foo\"bar") + expect("foo\\\\bar", "foo\\bar") + expect("foo\\'bar", "foo'bar") +} + +func TestEscapeQuotes(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesQuotes([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringQuotes([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\x00bar", "foo\x00bar") // not affected + expect("foo\nbar", "foo\nbar") // not affected + expect("foo\rbar", "foo\rbar") // not affected + expect("foo\x1abar", "foo\x1abar") // not affected + expect("foo''bar", "foo'bar") // affected + expect("foo\"bar", "foo\"bar") // not affected +} + +func TestAtomicBool(t *testing.T) { + var ab atomicBool + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(true) + if ab.value != 1 { + t.Fatal("Set(true) did not set value to 1") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(false) + if ab.value != 0 { + t.Fatal("Set(false) did not set value to 0") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(false) + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + if ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to fail") + } + if !ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to succeed") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + if ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to fail") + } + if !ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to succeed") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯ +} + +func TestAtomicError(t *testing.T) { + var ae atomicError + if ae.Value() != nil { + t.Fatal("Expected value to be nil") + } + + ae.Set(ErrMalformPkt) + if v := ae.Value(); v != ErrMalformPkt { + if v == nil { + t.Fatal("Value is still nil") + } + t.Fatal("Error did not match") + } + ae.Set(ErrPktSync) + if ae.Value() == ErrMalformPkt { + t.Fatal("Error still matches old error") + } + if v := ae.Value(); v != ErrPktSync { + t.Fatal("Error did not match") + } +} + +func TestIsolationLevelMapping(t *testing.T) { + data := []struct { + level driver.IsolationLevel + expected string + }{ + { + level: driver.IsolationLevel(sql.LevelReadCommitted), + expected: "READ COMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelRepeatableRead), + expected: "REPEATABLE READ", + }, + { + level: driver.IsolationLevel(sql.LevelReadUncommitted), + expected: "READ UNCOMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelSerializable), + expected: "SERIALIZABLE", + }, + } + + for i, td := range data { + if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil { + t.Fatal(i, td.expected, actual, err) + } + } + + // check unsupported mapping + expectedErr := "mysql: unsupported isolation level: 7" + actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable)) + if actual != "" || err == nil { + t.Fatal("Expected error on unsupported isolation level") + } + if err.Error() != expectedErr { + t.Fatalf("Expected error to be %q, got %q", expectedErr, err) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/.travis.yml b/vendor/github.com/siddontang/go-mysql/.travis.yml index cc0db3c..8f8eafd 100644 --- a/vendor/github.com/siddontang/go-mysql/.travis.yml +++ b/vendor/github.com/siddontang/go-mysql/.travis.yml @@ -1,32 +1,34 @@ language: go go: - - 1.6 - - 1.7 + - "1.9" + - "1.10" -dist: trusty -sudo: required addons: apt: + sources: + - mysql-5.7-trusty packages: - - mysql-server-5.6 - - mysql-client-core-5.6 - - mysql-client-5.6 + - mysql-server + - mysql-client + +before_install: + - sudo mysql -e "use mysql; update user set authentication_string=PASSWORD('') where User='root'; update user set plugin='mysql_native_password';FLUSH PRIVILEGES;" + - sudo mysql_upgrade -before_script: # stop mysql and use row-based format binlog - - "sudo /etc/init.d/mysql stop || true" + - "sudo service mysql stop || true" - "echo '[mysqld]' | sudo tee /etc/mysql/conf.d/replication.cnf" - "echo 'server-id=1' | sudo tee -a /etc/mysql/conf.d/replication.cnf" - - "echo 'log-bin=mysql' | sudo tee -a /etc/mysql/conf.d/replication.cnf" + - "echo 'log-bin=mysql' | sudo tee -a /etc/mysql/conf.d/replication.cnf" - "echo 'binlog-format = row' | sudo tee -a /etc/mysql/conf.d/replication.cnf" # Start mysql (avoid errors to have logs) - - "sudo /etc/init.d/mysql start || true" + - "sudo service mysql start || true" - "sudo tail -1000 /var/log/syslog" - mysql -e "CREATE DATABASE IF NOT EXISTS test;" -uroot script: - - make test \ No newline at end of file + - make test diff --git a/vendor/github.com/siddontang/go-mysql/Gopkg.lock b/vendor/github.com/siddontang/go-mysql/Gopkg.lock new file mode 100644 index 0000000..ae65b1d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/Gopkg.lock @@ -0,0 +1,78 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + name = "github.com/BurntSushi/toml" + packages = ["."] + revision = "b26d9c308763d68093482582cea63d69be07a0f0" + version = "v0.3.0" + +[[projects]] + branch = "master" + name = "github.com/go-sql-driver/mysql" + packages = ["."] + revision = "99ff426eb706cffe92ff3d058e168b278cabf7c7" + +[[projects]] + branch = "master" + name = "github.com/jmoiron/sqlx" + packages = [ + ".", + "reflectx" + ] + revision = "2aeb6a910c2b94f2d5eb53d9895d80e27264ec41" + +[[projects]] + branch = "master" + name = "github.com/juju/errors" + packages = ["."] + revision = "c7d06af17c68cd34c835053720b21f6549d9b0ee" + +[[projects]] + branch = "master" + name = "github.com/pingcap/check" + packages = ["."] + revision = "1c287c953996ab3a0bf535dba9d53d809d3dc0b6" + +[[projects]] + name = "github.com/satori/go.uuid" + packages = ["."] + revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3" + version = "v1.2.0" + +[[projects]] + name = "github.com/shopspring/decimal" + packages = ["."] + revision = "cd690d0c9e2447b1ef2a129a6b7b49077da89b8e" + version = "1.1.0" + +[[projects]] + branch = "master" + name = "github.com/siddontang/go" + packages = [ + "hack", + "sync2" + ] + revision = "2b7082d296ba89ae7ead0f977816bddefb65df9d" + +[[projects]] + branch = "master" + name = "github.com/siddontang/go-log" + packages = [ + "log", + "loggers" + ] + revision = "a4d157e46fa3e08b7e7ff329af341fa3ff86c02c" + +[[projects]] + name = "google.golang.org/appengine" + packages = ["cloudsql"] + revision = "b1f26356af11148e710935ed1ac8a7f5702c7612" + version = "v1.1.0" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + inputs-digest = "a1f9939938a58551bbb3f19411c9d1386995d36296de6f6fb5d858f5923db85e" + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/vendor/github.com/siddontang/go-mysql/Gopkg.toml b/vendor/github.com/siddontang/go-mysql/Gopkg.toml new file mode 100644 index 0000000..71df4b3 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/Gopkg.toml @@ -0,0 +1,56 @@ +# Gopkg.toml example +# +# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + + +[[constraint]] + name = "github.com/BurntSushi/toml" + version = "v0.3.0" + +[[constraint]] + name = "github.com/go-sql-driver/mysql" + branch = "master" + +[[constraint]] + branch = "master" + name = "github.com/juju/errors" + +[[constraint]] + name = "github.com/satori/go.uuid" + version = "v1.2.0" + +[[constraint]] + name = "github.com/shopspring/decimal" + version = "v1.1.0" + +[[constraint]] + branch = "master" + name = "github.com/siddontang/go" + +[prune] + go-tests = true + unused-packages = true + non-go = true + diff --git a/vendor/github.com/siddontang/go-mysql/Makefile b/vendor/github.com/siddontang/go-mysql/Makefile index 92744b1..3decd6c 100644 --- a/vendor/github.com/siddontang/go-mysql/Makefile +++ b/vendor/github.com/siddontang/go-mysql/Makefile @@ -1,33 +1,14 @@ all: build build: - rm -rf vendor && ln -s _vendor/vendor vendor go build -o bin/go-mysqlbinlog cmd/go-mysqlbinlog/main.go go build -o bin/go-mysqldump cmd/go-mysqldump/main.go go build -o bin/go-canal cmd/go-canal/main.go go build -o bin/go-binlogparser cmd/go-binlogparser/main.go - rm -rf vendor - + test: - rm -rf vendor && ln -s _vendor/vendor vendor go test --race -timeout 2m ./... - rm -rf vendor clean: go clean -i ./... - @rm -rf ./bin - -update_vendor: - which glide >/dev/null || curl https://glide.sh/get | sh - which glide-vc || go get -v -u github.com/sgotti/glide-vc - rm -r vendor && mv _vendor/vendor vendor || true - rm -rf _vendor -ifdef PKG - glide get --strip-vendor --skip-test ${PKG} -else - glide update --strip-vendor --skip-test -endif - @echo "removing test files" - glide vc --only-code --no-tests - mkdir -p _vendor - mv vendor _vendor/vendor + @rm -rf ./bin \ No newline at end of file diff --git a/vendor/github.com/siddontang/go-mysql/README.md b/vendor/github.com/siddontang/go-mysql/README.md index 4ae6697..0b958c7 100644 --- a/vendor/github.com/siddontang/go-mysql/README.md +++ b/vendor/github.com/siddontang/go-mysql/README.md @@ -25,9 +25,9 @@ cfg := replication.BinlogSyncerConfig { User: "root", Password: "", } -syncer := replication.NewBinlogSyncer(&cfg) +syncer := replication.NewBinlogSyncer(cfg) -// Start sync with sepcified binlog file and position +// Start sync with specified binlog file and position streamer, _ := syncer.StartSync(mysql.Position{binlogFile, binlogPos}) // or you can start a gtid replication like @@ -44,7 +44,7 @@ for { // or we can use a timeout context for { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - e, _ := s.GetEvent(ctx) + ev, err := s.GetEvent(ctx) cancel() if err == context.DeadlineExceeded { @@ -105,20 +105,21 @@ cfg.Dump.Tables = []string{"canal_test"} c, err := NewCanal(cfg) -type myRowsEventHandler struct { +type MyEventHandler struct { + DummyEventHandler } -func (h *myRowsEventHandler) Do(e *RowsEvent) error { +func (h *MyEventHandler) OnRow(e *RowsEvent) error { log.Infof("%s %v\n", e.Action, e.Rows) return nil } -func (h *myRowsEventHandler) String() string { - return "myRowsEventHandler" +func (h *MyEventHandler) String() string { + return "MyEventHandler" } // Register a handler to handle RowsEvent -c.RegRowsEventHandler(&MyRowsEventHandler{}) +c.SetEventHandler(&MyEventHandler{}) // Start canal c.Start() @@ -137,9 +138,16 @@ import ( "github.com/siddontang/go-mysql/client" ) -// Connect MySQL at 127.0.0.1:3306, with user root, an empty passowrd and database test +// Connect MySQL at 127.0.0.1:3306, with user root, an empty password and database test conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test") +// Or to use SSL/TLS connection if MySQL server supports TLS +//conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test", func(c *Conn) {c.UseSSL(true)}) + +// or to set your own client-side certificates for identity verification for security +//tlsConfig := NewClientTLSConfig(caPem, certPem, keyPem, false, "your-server-name") +//conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test", func(c *Conn) {c.SetTLSConfig(tlsConfig)}) + conn.Ping() // Insert @@ -156,10 +164,17 @@ v, _ := r.GetInt(0, 0) v, _ = r.GetIntByName(0, "id") ``` +Tested MySQL versions for the client include: +- 5.5.x +- 5.6.x +- 5.7.x +- 8.0.x + ## Server Server package supplies a framework to implement a simple MySQL server which can handle the packets from the MySQL client. -You can use it to build your own MySQL proxy. +You can use it to build your own MySQL proxy. The server connection is compatible with MySQL 5.5, 5.6, 5.7, and 8.0 versions, +so that most MySQL clients should be able to connect to the Server without modifications. ### Example @@ -173,14 +188,14 @@ l, _ := net.Listen("tcp", "127.0.0.1:4000") c, _ := l.Accept() -// Create a connection with user root and an empty passowrd -// We only an empty handler to handle command too +// Create a connection with user root and an empty password. +// You can use your own handler to handle command here. conn, _ := server.NewConn(c, "root", "", server.EmptyHandler{}) for { conn.HandleCommand() } -``` +``` Another shell @@ -189,6 +204,15 @@ mysql -h127.0.0.1 -P4000 -uroot -p //Becuase empty handler does nothing, so here the MySQL client can only connect the proxy server. :-) ``` +> ```NewConn()``` will use default server configurations: +> 1. automatically generate default server certificates and enable TLS/SSL support. +> 2. support three mainstream authentication methods **'mysql_native_password'**, **'caching_sha2_password'**, and **'sha256_password'** +> and use **'mysql_native_password'** as default. +> 3. use an in-memory user credential provider to store user and password. +> +> To customize server configurations, use ```NewServer()``` and create connection via ```NewCustomizedConn()```. + + ## Failover Failover supports to promote a new master and let other slaves replicate from it automatically when the old master was down. @@ -205,10 +229,12 @@ Although there are many companies use MySQL 5.0 - 5.5, I think upgrade MySQL to Driver is the package that you can use go-mysql with go database/sql like other drivers. A simple example: ``` +package main + import ( "database/sql" - - "github.com/siddontang/go-mysql/driver" + _ "github.com/siddontang/go-mysql/driver" ) func main() { @@ -221,6 +247,14 @@ func main() { We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) +## Donate + +If you like the project and want to buy me a cola, you can through: + +|PayPal|微信| +|------|---| +|[![](https://www.paypalobjects.com/webstatic/paypalme/images/pp_logo_small.png)](https://paypal.me/siddontang)|[![](https://github.com/siddontang/blog/blob/master/donate/weixin.png)| + ## Feedback go-mysql is still in development, your feedback is very welcome. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/buffer.go deleted file mode 100644 index 2001fea..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/buffer.go +++ /dev/null @@ -1,147 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -package mysql - -import ( - "io" - "net" - "time" -) - -const defaultBufSize = 4096 - -// A buffer which is used for both reading and writing. -// This is possible since communication on each connection is synchronous. -// In other words, we can't write and read simultaneously on the same connection. -// The buffer is similar to bufio.Reader / Writer but zero-copy-ish -// Also highly optimized for this particular use case. -type buffer struct { - buf []byte - nc net.Conn - idx int - length int - timeout time.Duration -} - -func newBuffer(nc net.Conn) buffer { - var b [defaultBufSize]byte - return buffer{ - buf: b[:], - nc: nc, - } -} - -// fill reads into the buffer until at least _need_ bytes are in it -func (b *buffer) fill(need int) error { - n := b.length - - // move existing data to the beginning - if n > 0 && b.idx > 0 { - copy(b.buf[0:n], b.buf[b.idx:]) - } - - // grow buffer if necessary - // TODO: let the buffer shrink again at some point - // Maybe keep the org buf slice and swap back? - if need > len(b.buf) { - // Round up to the next multiple of the default size - newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) - copy(newBuf, b.buf) - b.buf = newBuf - } - - b.idx = 0 - - for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { - return err - } - } - - nn, err := b.nc.Read(b.buf[n:]) - n += nn - - switch err { - case nil: - if n < need { - continue - } - b.length = n - return nil - - case io.EOF: - if n >= need { - b.length = n - return nil - } - return io.ErrUnexpectedEOF - - default: - return err - } - } -} - -// returns next N bytes from buffer. -// The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int) ([]byte, error) { - if b.length < need { - // refill - if err := b.fill(need); err != nil { - return nil, err - } - } - - offset := b.idx - b.idx += need - b.length -= need - return b.buf[offset:b.idx], nil -} - -// returns a buffer with the requested size. -// If possible, a slice from the existing buffer is returned. -// Otherwise a bigger buffer is made. -// Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { - if b.length > 0 { - return nil - } - - // test (cheap) general case first - if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] - } - - if length < maxPacketSize { - b.buf = make([]byte, length) - return b.buf - } - return make([]byte, length) -} - -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize -// Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { - return b.buf[:length] - } - return nil -} - -// takeCompleteBuffer returns the complete existing buffer. -// This can be used if the necessary buffer size is unknown. -// Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { - if b.length == 0 { - return b.buf - } - return nil -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/collations.go deleted file mode 100644 index 82079cf..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/collations.go +++ /dev/null @@ -1,250 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -package mysql - -const defaultCollation = "utf8_general_ci" - -// A list of available collations mapped to the internal ID. -// To update this map use the following MySQL query: -// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS -var collations = map[string]byte{ - "big5_chinese_ci": 1, - "latin2_czech_cs": 2, - "dec8_swedish_ci": 3, - "cp850_general_ci": 4, - "latin1_german1_ci": 5, - "hp8_english_ci": 6, - "koi8r_general_ci": 7, - "latin1_swedish_ci": 8, - "latin2_general_ci": 9, - "swe7_swedish_ci": 10, - "ascii_general_ci": 11, - "ujis_japanese_ci": 12, - "sjis_japanese_ci": 13, - "cp1251_bulgarian_ci": 14, - "latin1_danish_ci": 15, - "hebrew_general_ci": 16, - "tis620_thai_ci": 18, - "euckr_korean_ci": 19, - "latin7_estonian_cs": 20, - "latin2_hungarian_ci": 21, - "koi8u_general_ci": 22, - "cp1251_ukrainian_ci": 23, - "gb2312_chinese_ci": 24, - "greek_general_ci": 25, - "cp1250_general_ci": 26, - "latin2_croatian_ci": 27, - "gbk_chinese_ci": 28, - "cp1257_lithuanian_ci": 29, - "latin5_turkish_ci": 30, - "latin1_german2_ci": 31, - "armscii8_general_ci": 32, - "utf8_general_ci": 33, - "cp1250_czech_cs": 34, - "ucs2_general_ci": 35, - "cp866_general_ci": 36, - "keybcs2_general_ci": 37, - "macce_general_ci": 38, - "macroman_general_ci": 39, - "cp852_general_ci": 40, - "latin7_general_ci": 41, - "latin7_general_cs": 42, - "macce_bin": 43, - "cp1250_croatian_ci": 44, - "utf8mb4_general_ci": 45, - "utf8mb4_bin": 46, - "latin1_bin": 47, - "latin1_general_ci": 48, - "latin1_general_cs": 49, - "cp1251_bin": 50, - "cp1251_general_ci": 51, - "cp1251_general_cs": 52, - "macroman_bin": 53, - "utf16_general_ci": 54, - "utf16_bin": 55, - "utf16le_general_ci": 56, - "cp1256_general_ci": 57, - "cp1257_bin": 58, - "cp1257_general_ci": 59, - "utf32_general_ci": 60, - "utf32_bin": 61, - "utf16le_bin": 62, - "binary": 63, - "armscii8_bin": 64, - "ascii_bin": 65, - "cp1250_bin": 66, - "cp1256_bin": 67, - "cp866_bin": 68, - "dec8_bin": 69, - "greek_bin": 70, - "hebrew_bin": 71, - "hp8_bin": 72, - "keybcs2_bin": 73, - "koi8r_bin": 74, - "koi8u_bin": 75, - "latin2_bin": 77, - "latin5_bin": 78, - "latin7_bin": 79, - "cp850_bin": 80, - "cp852_bin": 81, - "swe7_bin": 82, - "utf8_bin": 83, - "big5_bin": 84, - "euckr_bin": 85, - "gb2312_bin": 86, - "gbk_bin": 87, - "sjis_bin": 88, - "tis620_bin": 89, - "ucs2_bin": 90, - "ujis_bin": 91, - "geostd8_general_ci": 92, - "geostd8_bin": 93, - "latin1_spanish_ci": 94, - "cp932_japanese_ci": 95, - "cp932_bin": 96, - "eucjpms_japanese_ci": 97, - "eucjpms_bin": 98, - "cp1250_polish_ci": 99, - "utf16_unicode_ci": 101, - "utf16_icelandic_ci": 102, - "utf16_latvian_ci": 103, - "utf16_romanian_ci": 104, - "utf16_slovenian_ci": 105, - "utf16_polish_ci": 106, - "utf16_estonian_ci": 107, - "utf16_spanish_ci": 108, - "utf16_swedish_ci": 109, - "utf16_turkish_ci": 110, - "utf16_czech_ci": 111, - "utf16_danish_ci": 112, - "utf16_lithuanian_ci": 113, - "utf16_slovak_ci": 114, - "utf16_spanish2_ci": 115, - "utf16_roman_ci": 116, - "utf16_persian_ci": 117, - "utf16_esperanto_ci": 118, - "utf16_hungarian_ci": 119, - "utf16_sinhala_ci": 120, - "utf16_german2_ci": 121, - "utf16_croatian_ci": 122, - "utf16_unicode_520_ci": 123, - "utf16_vietnamese_ci": 124, - "ucs2_unicode_ci": 128, - "ucs2_icelandic_ci": 129, - "ucs2_latvian_ci": 130, - "ucs2_romanian_ci": 131, - "ucs2_slovenian_ci": 132, - "ucs2_polish_ci": 133, - "ucs2_estonian_ci": 134, - "ucs2_spanish_ci": 135, - "ucs2_swedish_ci": 136, - "ucs2_turkish_ci": 137, - "ucs2_czech_ci": 138, - "ucs2_danish_ci": 139, - "ucs2_lithuanian_ci": 140, - "ucs2_slovak_ci": 141, - "ucs2_spanish2_ci": 142, - "ucs2_roman_ci": 143, - "ucs2_persian_ci": 144, - "ucs2_esperanto_ci": 145, - "ucs2_hungarian_ci": 146, - "ucs2_sinhala_ci": 147, - "ucs2_german2_ci": 148, - "ucs2_croatian_ci": 149, - "ucs2_unicode_520_ci": 150, - "ucs2_vietnamese_ci": 151, - "ucs2_general_mysql500_ci": 159, - "utf32_unicode_ci": 160, - "utf32_icelandic_ci": 161, - "utf32_latvian_ci": 162, - "utf32_romanian_ci": 163, - "utf32_slovenian_ci": 164, - "utf32_polish_ci": 165, - "utf32_estonian_ci": 166, - "utf32_spanish_ci": 167, - "utf32_swedish_ci": 168, - "utf32_turkish_ci": 169, - "utf32_czech_ci": 170, - "utf32_danish_ci": 171, - "utf32_lithuanian_ci": 172, - "utf32_slovak_ci": 173, - "utf32_spanish2_ci": 174, - "utf32_roman_ci": 175, - "utf32_persian_ci": 176, - "utf32_esperanto_ci": 177, - "utf32_hungarian_ci": 178, - "utf32_sinhala_ci": 179, - "utf32_german2_ci": 180, - "utf32_croatian_ci": 181, - "utf32_unicode_520_ci": 182, - "utf32_vietnamese_ci": 183, - "utf8_unicode_ci": 192, - "utf8_icelandic_ci": 193, - "utf8_latvian_ci": 194, - "utf8_romanian_ci": 195, - "utf8_slovenian_ci": 196, - "utf8_polish_ci": 197, - "utf8_estonian_ci": 198, - "utf8_spanish_ci": 199, - "utf8_swedish_ci": 200, - "utf8_turkish_ci": 201, - "utf8_czech_ci": 202, - "utf8_danish_ci": 203, - "utf8_lithuanian_ci": 204, - "utf8_slovak_ci": 205, - "utf8_spanish2_ci": 206, - "utf8_roman_ci": 207, - "utf8_persian_ci": 208, - "utf8_esperanto_ci": 209, - "utf8_hungarian_ci": 210, - "utf8_sinhala_ci": 211, - "utf8_german2_ci": 212, - "utf8_croatian_ci": 213, - "utf8_unicode_520_ci": 214, - "utf8_vietnamese_ci": 215, - "utf8_general_mysql500_ci": 223, - "utf8mb4_unicode_ci": 224, - "utf8mb4_icelandic_ci": 225, - "utf8mb4_latvian_ci": 226, - "utf8mb4_romanian_ci": 227, - "utf8mb4_slovenian_ci": 228, - "utf8mb4_polish_ci": 229, - "utf8mb4_estonian_ci": 230, - "utf8mb4_spanish_ci": 231, - "utf8mb4_swedish_ci": 232, - "utf8mb4_turkish_ci": 233, - "utf8mb4_czech_ci": 234, - "utf8mb4_danish_ci": 235, - "utf8mb4_lithuanian_ci": 236, - "utf8mb4_slovak_ci": 237, - "utf8mb4_spanish2_ci": 238, - "utf8mb4_roman_ci": 239, - "utf8mb4_persian_ci": 240, - "utf8mb4_esperanto_ci": 241, - "utf8mb4_hungarian_ci": 242, - "utf8mb4_sinhala_ci": 243, - "utf8mb4_german2_ci": 244, - "utf8mb4_croatian_ci": 245, - "utf8mb4_unicode_520_ci": 246, - "utf8mb4_vietnamese_ci": 247, -} - -// A blacklist of collations which is unsafe to interpolate parameters. -// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. -var unsafeCollations = map[string]bool{ - "big5_chinese_ci": true, - "sjis_japanese_ci": true, - "gbk_chinese_ci": true, - "big5_bin": true, - "gb2312_bin": true, - "gbk_bin": true, - "sjis_bin": true, - "cp932_japanese_ci": true, - "cp932_bin": true, -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/driver.go deleted file mode 100644 index 899f955..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/driver.go +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// Package mysql provides a MySQL driver for Go's database/sql package -// -// The driver should be used via the database/sql package: -// -// import "database/sql" -// import _ "github.com/go-sql-driver/mysql" -// -// db, err := sql.Open("mysql", "user:password@/dbname") -// -// See https://github.com/go-sql-driver/mysql#usage for details -package mysql - -import ( - "database/sql" - "database/sql/driver" - "net" -) - -// MySQLDriver is exported to make the driver directly accessible. -// In general the driver is used via the database/sql package. -type MySQLDriver struct{} - -// DialFunc is a function which can be used to establish the network connection. -// Custom dial functions must be registered with RegisterDial -type DialFunc func(addr string) (net.Conn, error) - -var dials map[string]DialFunc - -// RegisterDial registers a custom dial function. It can then be used by the -// network address mynet(addr), where mynet is the registered new network. -// addr is passed as a parameter to the dial function. -func RegisterDial(net string, dial DialFunc) { - if dials == nil { - dials = make(map[string]DialFunc) - } - dials[net] = dial -} - -// Open new Connection. -// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated -func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { - var err error - - // New mysqlConn - mc := &mysqlConn{ - maxPacketAllowed: maxPacketSize, - maxWriteSize: maxPacketSize - 1, - } - mc.cfg, err = ParseDSN(dsn) - if err != nil { - return nil, err - } - mc.parseTime = mc.cfg.ParseTime - mc.strict = mc.cfg.Strict - - // Connect to Server - if dial, ok := dials[mc.cfg.Net]; ok { - mc.netConn, err = dial(mc.cfg.Addr) - } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) - } - if err != nil { - return nil, err - } - - // Enable TCP Keepalives on TCP connections - if tc, ok := mc.netConn.(*net.TCPConn); ok { - if err := tc.SetKeepAlive(true); err != nil { - // Don't send COM_QUIT before handshake. - mc.netConn.Close() - mc.netConn = nil - return nil, err - } - } - - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout - - // Reading Handshake Initialization Packet - cipher, err := mc.readInitPacket() - if err != nil { - mc.cleanup() - return nil, err - } - - // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { - mc.cleanup() - return nil, err - } - - // Handle response to auth packet, switch methods if possible - if err = handleAuthResult(mc, cipher); err != nil { - // Authentication failed and MySQL has already closed the connection - // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). - // Do not send COM_QUIT, just cleanup and return the error. - mc.cleanup() - return nil, err - } - - // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") - if err != nil { - mc.Close() - return nil, err - } - mc.maxPacketAllowed = stringToInt(maxap) - 1 - if mc.maxPacketAllowed < maxPacketSize { - mc.maxWriteSize = mc.maxPacketAllowed - } - - // Handle DSN Params - err = mc.handleParams() - if err != nil { - mc.Close() - return nil, err - } - - return mc, nil -} - -func handleAuthResult(mc *mysqlConn, cipher []byte) error { - // Read Result Packet - err := mc.readResultOK() - if err == nil { - return nil // auth successful - } - - if mc.cfg == nil { - return err // auth failed and retry not possible - } - - // Retry auth if configured to do so. - if mc.cfg.AllowOldPasswords && err == ErrOldPassword { - // Retry with old authentication method. Note: there are edge cases - // where this should work but doesn't; this is currently "wontfix": - // https://github.com/go-sql-driver/mysql/issues/184 - if err = mc.writeOldAuthPacket(cipher); err != nil { - return err - } - err = mc.readResultOK() - } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { - // Retry with clear text password for - // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html - // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(); err != nil { - return err - } - err = mc.readResultOK() - } - return err -} - -func init() { - sql.Register("mysql", &MySQLDriver{}) -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/rows.go deleted file mode 100644 index c08255e..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/rows.go +++ /dev/null @@ -1,112 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -package mysql - -import ( - "database/sql/driver" - "io" -) - -type mysqlField struct { - tableName string - name string - flags fieldFlag - fieldType byte - decimals byte -} - -type mysqlRows struct { - mc *mysqlConn - columns []mysqlField -} - -type binaryRows struct { - mysqlRows -} - -type textRows struct { - mysqlRows -} - -type emptyRows struct{} - -func (rows *mysqlRows) Columns() []string { - columns := make([]string, len(rows.columns)) - if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { - for i := range columns { - if tableName := rows.columns[i].tableName; len(tableName) > 0 { - columns[i] = tableName + "." + rows.columns[i].name - } else { - columns[i] = rows.columns[i].name - } - } - } else { - for i := range columns { - columns[i] = rows.columns[i].name - } - } - return columns -} - -func (rows *mysqlRows) Close() error { - mc := rows.mc - if mc == nil { - return nil - } - if mc.netConn == nil { - return ErrInvalidConn - } - - // Remove unread packets from stream - err := mc.readUntilEOF() - if err == nil { - if err = mc.discardResults(); err != nil { - return err - } - } - - rows.mc = nil - return err -} - -func (rows *binaryRows) Next(dest []driver.Value) error { - if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn - } - - // Fetch next row from stream - return rows.readRow(dest) - } - return io.EOF -} - -func (rows *textRows) Next(dest []driver.Value) error { - if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn - } - - // Fetch next row from stream - return rows.readRow(dest) - } - return io.EOF -} - -func (rows emptyRows) Columns() []string { - return nil -} - -func (rows emptyRows) Close() error { - return nil -} - -func (rows emptyRows) Next(dest []driver.Value) error { - return io.EOF -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/statement.go deleted file mode 100644 index ead9a6b..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/statement.go +++ /dev/null @@ -1,150 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -package mysql - -import ( - "database/sql/driver" - "fmt" - "reflect" - "strconv" -) - -type mysqlStmt struct { - mc *mysqlConn - id uint32 - paramCount int - columns []mysqlField // cached from the first query -} - -func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.netConn == nil { - errLog.Print(ErrInvalidConn) - return driver.ErrBadConn - } - - err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) - stmt.mc = nil - return err -} - -func (stmt *mysqlStmt) NumInput() int { - return stmt.paramCount -} - -func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { - return converter{} -} - -func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.netConn == nil { - errLog.Print(ErrInvalidConn) - return nil, driver.ErrBadConn - } - // Send command - err := stmt.writeExecutePacket(args) - if err != nil { - return nil, err - } - - mc := stmt.mc - - mc.affectedRows = 0 - mc.insertId = 0 - - // Read Result - resLen, err := mc.readResultSetHeaderPacket() - if err == nil { - if resLen > 0 { - // Columns - err = mc.readUntilEOF() - if err != nil { - return nil, err - } - - // Rows - err = mc.readUntilEOF() - } - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil - } - } - - return nil, err -} - -func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.mc.netConn == nil { - errLog.Print(ErrInvalidConn) - return nil, driver.ErrBadConn - } - // Send command - err := stmt.writeExecutePacket(args) - if err != nil { - return nil, err - } - - mc := stmt.mc - - // Read Result - resLen, err := mc.readResultSetHeaderPacket() - if err != nil { - return nil, err - } - - rows := new(binaryRows) - - if resLen > 0 { - rows.mc = mc - // Columns - // If not cached, read them and cache them - if stmt.columns == nil { - rows.columns, err = mc.readColumns(resLen) - stmt.columns = rows.columns - } else { - rows.columns = stmt.columns - err = mc.readUntilEOF() - } - } - - return rows, err -} - -type converter struct{} - -func (c converter) ConvertValue(v interface{}) (driver.Value, error) { - if driver.IsValue(v) { - return v, nil - } - - rv := reflect.ValueOf(v) - switch rv.Kind() { - case reflect.Ptr: - // indirect pointers - if rv.IsNil() { - return nil, nil - } - return c.ConvertValue(rv.Elem().Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return rv.Int(), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return int64(rv.Uint()), nil - case reflect.Uint64: - u64 := rv.Uint() - if u64 >= 1<<63 { - return strconv.FormatUint(u64, 10), nil - } - return int64(u64), nil - case reflect.Float32, reflect.Float64: - return rv.Float(), nil - } - return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/LICENSE deleted file mode 100644 index 6600f1c..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/LICENSE +++ /dev/null @@ -1,165 +0,0 @@ -GNU LESSER GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 - - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - - This version of the GNU Lesser General Public License incorporates -the terms and conditions of version 3 of the GNU General Public -License, supplemented by the additional permissions listed below. - - 0. Additional Definitions. - - As used herein, "this License" refers to version 3 of the GNU Lesser -General Public License, and the "GNU GPL" refers to version 3 of the GNU -General Public License. - - "The Library" refers to a covered work governed by this License, -other than an Application or a Combined Work as defined below. - - An "Application" is any work that makes use of an interface provided -by the Library, but which is not otherwise based on the Library. -Defining a subclass of a class defined by the Library is deemed a mode -of using an interface provided by the Library. - - A "Combined Work" is a work produced by combining or linking an -Application with the Library. The particular version of the Library -with which the Combined Work was made is also called the "Linked -Version". - - The "Minimal Corresponding Source" for a Combined Work means the -Corresponding Source for the Combined Work, excluding any source code -for portions of the Combined Work that, considered in isolation, are -based on the Application, and not on the Linked Version. - - The "Corresponding Application Code" for a Combined Work means the -object code and/or source code for the Application, including any data -and utility programs needed for reproducing the Combined Work from the -Application, but excluding the System Libraries of the Combined Work. - - 1. Exception to Section 3 of the GNU GPL. - - You may convey a covered work under sections 3 and 4 of this License -without being bound by section 3 of the GNU GPL. - - 2. Conveying Modified Versions. - - If you modify a copy of the Library, and, in your modifications, a -facility refers to a function or data to be supplied by an Application -that uses the facility (other than as an argument passed when the -facility is invoked), then you may convey a copy of the modified -version: - - a) under this License, provided that you make a good faith effort to - ensure that, in the event an Application does not supply the - function or data, the facility still operates, and performs - whatever part of its purpose remains meaningful, or - - b) under the GNU GPL, with none of the additional permissions of - this License applicable to that copy. - - 3. Object Code Incorporating Material from Library Header Files. - - The object code form of an Application may incorporate material from -a header file that is part of the Library. You may convey such object -code under terms of your choice, provided that, if the incorporated -material is not limited to numerical parameters, data structure -layouts and accessors, or small macros, inline functions and templates -(ten or fewer lines in length), you do both of the following: - - a) Give prominent notice with each copy of the object code that the - Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the object code with a copy of the GNU GPL and this license - document. - - 4. Combined Works. - - You may convey a Combined Work under terms of your choice that, -taken together, effectively do not restrict modification of the -portions of the Library contained in the Combined Work and reverse -engineering for debugging such modifications, if you also do each of -the following: - - a) Give prominent notice with each copy of the Combined Work that - the Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the Combined Work with a copy of the GNU GPL and this license - document. - - c) For a Combined Work that displays copyright notices during - execution, include the copyright notice for the Library among - these notices, as well as a reference directing the user to the - copies of the GNU GPL and this license document. - - d) Do one of the following: - - 0) Convey the Minimal Corresponding Source under the terms of this - License, and the Corresponding Application Code in a form - suitable for, and under terms that permit, the user to - recombine or relink the Application with a modified version of - the Linked Version to produce a modified Combined Work, in the - manner specified by section 6 of the GNU GPL for conveying - Corresponding Source. - - 1) Use a suitable shared library mechanism for linking with the - Library. A suitable mechanism is one that (a) uses at run time - a copy of the Library already present on the user's computer - system, and (b) will operate properly with a modified version - of the Library that is interface-compatible with the Linked - Version. - - e) Provide Installation Information, but only if you would otherwise - be required to provide such information under section 6 of the - GNU GPL, and only to the extent that such information is - necessary to install and execute a modified version of the - Combined Work produced by recombining or relinking the - Application with a modified version of the Linked Version. (If - you use option 4d0, the Installation Information must accompany - the Minimal Corresponding Source and Corresponding Application - Code. If you use option 4d1, you must provide the Installation - Information in the manner specified by section 6 of the GNU GPL - for conveying Corresponding Source.) - - 5. Combined Libraries. - - You may place library facilities that are a work based on the -Library side by side in a single library together with other library -facilities that are not Applications and are not covered by this -License, and convey such a combined library under terms of your -choice, if you do both of the following: - - a) Accompany the combined library with a copy of the same work based - on the Library, uncombined with any other library facilities, - conveyed under the terms of this License. - - b) Give prominent notice with the combined library that part of it - is a work based on the Library, and explaining where to find the - accompanying uncombined form of the same work. - - 6. Revised Versions of the GNU Lesser General Public License. - - The Free Software Foundation may publish revised and/or new versions -of the GNU Lesser General Public License from time to time. Such new -versions will be similar in spirit to the present version, but may -differ in detail to address new problems or concerns. - - Each version is given a distinguishing version number. If the -Library as you received it specifies that a certain numbered version -of the GNU Lesser General Public License "or any later version" -applies to it, you have the option of following the terms and -conditions either of that published version or of any later version -published by the Free Software Foundation. If the Library as you -received it does not specify a version number of the GNU Lesser -General Public License, you may choose any version of the GNU Lesser -General Public License ever published by the Free Software Foundation. - - If the Library as you received it specifies that a proxy can decide -whether future versions of the GNU Lesser General Public License shall -apply, that proxy's public statement of acceptance of any version is -permanent authorization for you to choose that version for the -Library. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_unix.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_unix.go deleted file mode 100644 index 37f407d..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_unix.go +++ /dev/null @@ -1,18 +0,0 @@ -// +build freebsd openbsd netbsd dragonfly darwin linux - -package log - -import ( - "log" - "os" - "syscall" -) - -func CrashLog(file string) { - f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - log.Println(err.Error()) - } else { - syscall.Dup2(int(f.Fd()), 2) - } -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_win.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_win.go deleted file mode 100644 index 7d612ee..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_win.go +++ /dev/null @@ -1,37 +0,0 @@ -// +build windows - -package log - -import ( - "log" - "os" - "syscall" -) - -var ( - kernel32 = syscall.MustLoadDLL("kernel32.dll") - procSetStdHandle = kernel32.MustFindProc("SetStdHandle") -) - -func setStdHandle(stdhandle int32, handle syscall.Handle) error { - r0, _, e1 := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdhandle), uintptr(handle), 0) - if r0 == 0 { - if e1 != 0 { - return error(e1) - } - return syscall.EINVAL - } - return nil -} - -func CrashLog(file string) { - f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - log.Println(err.Error()) - } else { - err = setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(f.Fd())) - if err != nil { - log.Println(err.Error()) - } - } -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/log.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/log.go deleted file mode 100644 index 896b393..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/log.go +++ /dev/null @@ -1,380 +0,0 @@ -//high level log wrapper, so it can output different log based on level -package log - -import ( - "fmt" - "io" - "log" - "os" - "runtime" - "sync" - "time" -) - -const ( - Ldate = log.Ldate - Llongfile = log.Llongfile - Lmicroseconds = log.Lmicroseconds - Lshortfile = log.Lshortfile - LstdFlags = log.LstdFlags - Ltime = log.Ltime -) - -type ( - LogLevel int - LogType int -) - -const ( - LOG_FATAL = LogType(0x1) - LOG_ERROR = LogType(0x2) - LOG_WARNING = LogType(0x4) - LOG_INFO = LogType(0x8) - LOG_DEBUG = LogType(0x10) -) - -const ( - LOG_LEVEL_NONE = LogLevel(0x0) - LOG_LEVEL_FATAL = LOG_LEVEL_NONE | LogLevel(LOG_FATAL) - LOG_LEVEL_ERROR = LOG_LEVEL_FATAL | LogLevel(LOG_ERROR) - LOG_LEVEL_WARN = LOG_LEVEL_ERROR | LogLevel(LOG_WARNING) - LOG_LEVEL_INFO = LOG_LEVEL_WARN | LogLevel(LOG_INFO) - LOG_LEVEL_DEBUG = LOG_LEVEL_INFO | LogLevel(LOG_DEBUG) - LOG_LEVEL_ALL = LOG_LEVEL_DEBUG -) - -const FORMAT_TIME_DAY string = "20060102" -const FORMAT_TIME_HOUR string = "2006010215" - -var _log *logger = New() - -func init() { - SetFlags(Ldate | Ltime | Lshortfile) - SetHighlighting(runtime.GOOS != "windows") -} - -func Logger() *log.Logger { - return _log._log -} - -func SetLevel(level LogLevel) { - _log.SetLevel(level) -} -func GetLogLevel() LogLevel { - return _log.level -} - -func SetOutput(out io.Writer) { - _log.SetOutput(out) -} - -func SetOutputByName(path string) error { - return _log.SetOutputByName(path) -} - -func SetFlags(flags int) { - _log._log.SetFlags(flags) -} - -func Info(v ...interface{}) { - _log.Info(v...) -} - -func Infof(format string, v ...interface{}) { - _log.Infof(format, v...) -} - -func Debug(v ...interface{}) { - _log.Debug(v...) -} - -func Debugf(format string, v ...interface{}) { - _log.Debugf(format, v...) -} - -func Warn(v ...interface{}) { - _log.Warning(v...) -} - -func Warnf(format string, v ...interface{}) { - _log.Warningf(format, v...) -} - -func Warning(v ...interface{}) { - _log.Warning(v...) -} - -func Warningf(format string, v ...interface{}) { - _log.Warningf(format, v...) -} - -func Error(v ...interface{}) { - _log.Error(v...) -} - -func Errorf(format string, v ...interface{}) { - _log.Errorf(format, v...) -} - -func Fatal(v ...interface{}) { - _log.Fatal(v...) -} - -func Fatalf(format string, v ...interface{}) { - _log.Fatalf(format, v...) -} - -func SetLevelByString(level string) { - _log.SetLevelByString(level) -} - -func SetHighlighting(highlighting bool) { - _log.SetHighlighting(highlighting) -} - -func SetRotateByDay() { - _log.SetRotateByDay() -} - -func SetRotateByHour() { - _log.SetRotateByHour() -} - -type logger struct { - _log *log.Logger - level LogLevel - highlighting bool - - dailyRolling bool - hourRolling bool - - fileName string - logSuffix string - fd *os.File - - lock sync.Mutex -} - -func (l *logger) SetHighlighting(highlighting bool) { - l.highlighting = highlighting -} - -func (l *logger) SetLevel(level LogLevel) { - l.level = level -} - -func (l *logger) SetLevelByString(level string) { - l.level = StringToLogLevel(level) -} - -func (l *logger) SetRotateByDay() { - l.dailyRolling = true - l.logSuffix = genDayTime(time.Now()) -} - -func (l *logger) SetRotateByHour() { - l.hourRolling = true - l.logSuffix = genHourTime(time.Now()) -} - -func (l *logger) rotate() error { - l.lock.Lock() - defer l.lock.Unlock() - - var suffix string - if l.dailyRolling { - suffix = genDayTime(time.Now()) - } else if l.hourRolling { - suffix = genHourTime(time.Now()) - } else { - return nil - } - - // Notice: if suffix is not equal to l.LogSuffix, then rotate - if suffix != l.logSuffix { - err := l.doRotate(suffix) - if err != nil { - return err - } - } - - return nil -} - -func (l *logger) doRotate(suffix string) error { - // Notice: Not check error, is this ok? - l.fd.Close() - - lastFileName := l.fileName + "." + l.logSuffix - err := os.Rename(l.fileName, lastFileName) - if err != nil { - return err - } - - err = l.SetOutputByName(l.fileName) - if err != nil { - return err - } - - l.logSuffix = suffix - - return nil -} - -func (l *logger) SetOutput(out io.Writer) { - l._log = log.New(out, l._log.Prefix(), l._log.Flags()) -} - -func (l *logger) SetOutputByName(path string) error { - f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666) - if err != nil { - log.Fatal(err) - } - - l.SetOutput(f) - - l.fileName = path - l.fd = f - - return err -} - -func (l *logger) log(t LogType, v ...interface{}) { - if l.level|LogLevel(t) != l.level { - return - } - - err := l.rotate() - if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err.Error()) - return - } - - v1 := make([]interface{}, len(v)+2) - logStr, logColor := LogTypeToString(t) - if l.highlighting { - v1[0] = "\033" + logColor + "m[" + logStr + "]" - copy(v1[1:], v) - v1[len(v)+1] = "\033[0m" - } else { - v1[0] = "[" + logStr + "]" - copy(v1[1:], v) - v1[len(v)+1] = "" - } - - s := fmt.Sprintln(v1...) - l._log.Output(4, s) -} - -func (l *logger) logf(t LogType, format string, v ...interface{}) { - if l.level|LogLevel(t) != l.level { - return - } - - err := l.rotate() - if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err.Error()) - return - } - - logStr, logColor := LogTypeToString(t) - var s string - if l.highlighting { - s = "\033" + logColor + "m[" + logStr + "] " + fmt.Sprintf(format, v...) + "\033[0m" - } else { - s = "[" + logStr + "] " + fmt.Sprintf(format, v...) - } - l._log.Output(4, s) -} - -func (l *logger) Fatal(v ...interface{}) { - l.log(LOG_FATAL, v...) - os.Exit(-1) -} - -func (l *logger) Fatalf(format string, v ...interface{}) { - l.logf(LOG_FATAL, format, v...) - os.Exit(-1) -} - -func (l *logger) Error(v ...interface{}) { - l.log(LOG_ERROR, v...) -} - -func (l *logger) Errorf(format string, v ...interface{}) { - l.logf(LOG_ERROR, format, v...) -} - -func (l *logger) Warning(v ...interface{}) { - l.log(LOG_WARNING, v...) -} - -func (l *logger) Warningf(format string, v ...interface{}) { - l.logf(LOG_WARNING, format, v...) -} - -func (l *logger) Debug(v ...interface{}) { - l.log(LOG_DEBUG, v...) -} - -func (l *logger) Debugf(format string, v ...interface{}) { - l.logf(LOG_DEBUG, format, v...) -} - -func (l *logger) Info(v ...interface{}) { - l.log(LOG_INFO, v...) -} - -func (l *logger) Infof(format string, v ...interface{}) { - l.logf(LOG_INFO, format, v...) -} - -func StringToLogLevel(level string) LogLevel { - switch level { - case "fatal": - return LOG_LEVEL_FATAL - case "error": - return LOG_LEVEL_ERROR - case "warn": - return LOG_LEVEL_WARN - case "warning": - return LOG_LEVEL_WARN - case "debug": - return LOG_LEVEL_DEBUG - case "info": - return LOG_LEVEL_INFO - } - return LOG_LEVEL_ALL -} - -func LogTypeToString(t LogType) (string, string) { - switch t { - case LOG_FATAL: - return "fatal", "[0;31" - case LOG_ERROR: - return "error", "[0;31" - case LOG_WARNING: - return "warning", "[0;33" - case LOG_DEBUG: - return "debug", "[0;36" - case LOG_INFO: - return "info", "[0;37" - } - return "unknown", "[0;37" -} - -func genDayTime(t time.Time) string { - return t.Format(FORMAT_TIME_DAY) -} - -func genHourTime(t time.Time) string { - return t.Format(FORMAT_TIME_HOUR) -} - -func New() *logger { - return Newlogger(os.Stderr, "") -} - -func Newlogger(w io.Writer, prefix string) *logger { - return &logger{_log: log.New(w, prefix, LstdFlags), level: LOG_LEVEL_ALL, highlighting: true} -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/uuid.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/uuid.go deleted file mode 100644 index 9c7fbaa..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/uuid.go +++ /dev/null @@ -1,488 +0,0 @@ -// Copyright (C) 2013-2015 by Maxim Bublis -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -// Package uuid provides implementation of Universally Unique Identifier (UUID). -// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and -// version 2 (as specified in DCE 1.1). -package uuid - -import ( - "bytes" - "crypto/md5" - "crypto/rand" - "crypto/sha1" - "database/sql/driver" - "encoding/binary" - "encoding/hex" - "fmt" - "hash" - "net" - "os" - "sync" - "time" -) - -// UUID layout variants. -const ( - VariantNCS = iota - VariantRFC4122 - VariantMicrosoft - VariantFuture -) - -// UUID DCE domains. -const ( - DomainPerson = iota - DomainGroup - DomainOrg -) - -// Difference in 100-nanosecond intervals between -// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970). -const epochStart = 122192928000000000 - -// Used in string method conversion -const dash byte = '-' - -// UUID v1/v2 storage. -var ( - storageMutex sync.Mutex - storageOnce sync.Once - epochFunc = unixTimeFunc - clockSequence uint16 - lastTime uint64 - hardwareAddr [6]byte - posixUID = uint32(os.Getuid()) - posixGID = uint32(os.Getgid()) -) - -// String parse helpers. -var ( - urnPrefix = []byte("urn:uuid:") - byteGroups = []int{8, 4, 4, 4, 12} -) - -func initClockSequence() { - buf := make([]byte, 2) - safeRandom(buf) - clockSequence = binary.BigEndian.Uint16(buf) -} - -func initHardwareAddr() { - interfaces, err := net.Interfaces() - if err == nil { - for _, iface := range interfaces { - if len(iface.HardwareAddr) >= 6 { - copy(hardwareAddr[:], iface.HardwareAddr) - return - } - } - } - - // Initialize hardwareAddr randomly in case - // of real network interfaces absence - safeRandom(hardwareAddr[:]) - - // Set multicast bit as recommended in RFC 4122 - hardwareAddr[0] |= 0x01 -} - -func initStorage() { - initClockSequence() - initHardwareAddr() -} - -func safeRandom(dest []byte) { - if _, err := rand.Read(dest); err != nil { - panic(err) - } -} - -// Returns difference in 100-nanosecond intervals between -// UUID epoch (October 15, 1582) and current time. -// This is default epoch calculation function. -func unixTimeFunc() uint64 { - return epochStart + uint64(time.Now().UnixNano()/100) -} - -// UUID representation compliant with specification -// described in RFC 4122. -type UUID [16]byte - -// NullUUID can be used with the standard sql package to represent a -// UUID value that can be NULL in the database -type NullUUID struct { - UUID UUID - Valid bool -} - -// The nil UUID is special form of UUID that is specified to have all -// 128 bits set to zero. -var Nil = UUID{} - -// Predefined namespace UUIDs. -var ( - NamespaceDNS, _ = FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") - NamespaceURL, _ = FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8") - NamespaceOID, _ = FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8") - NamespaceX500, _ = FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8") -) - -// And returns result of binary AND of two UUIDs. -func And(u1 UUID, u2 UUID) UUID { - u := UUID{} - for i := 0; i < 16; i++ { - u[i] = u1[i] & u2[i] - } - return u -} - -// Or returns result of binary OR of two UUIDs. -func Or(u1 UUID, u2 UUID) UUID { - u := UUID{} - for i := 0; i < 16; i++ { - u[i] = u1[i] | u2[i] - } - return u -} - -// Equal returns true if u1 and u2 equals, otherwise returns false. -func Equal(u1 UUID, u2 UUID) bool { - return bytes.Equal(u1[:], u2[:]) -} - -// Version returns algorithm version used to generate UUID. -func (u UUID) Version() uint { - return uint(u[6] >> 4) -} - -// Variant returns UUID layout variant. -func (u UUID) Variant() uint { - switch { - case (u[8] & 0x80) == 0x00: - return VariantNCS - case (u[8]&0xc0)|0x80 == 0x80: - return VariantRFC4122 - case (u[8]&0xe0)|0xc0 == 0xc0: - return VariantMicrosoft - } - return VariantFuture -} - -// Bytes returns bytes slice representation of UUID. -func (u UUID) Bytes() []byte { - return u[:] -} - -// Returns canonical string representation of UUID: -// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. -func (u UUID) String() string { - buf := make([]byte, 36) - - hex.Encode(buf[0:8], u[0:4]) - buf[8] = dash - hex.Encode(buf[9:13], u[4:6]) - buf[13] = dash - hex.Encode(buf[14:18], u[6:8]) - buf[18] = dash - hex.Encode(buf[19:23], u[8:10]) - buf[23] = dash - hex.Encode(buf[24:], u[10:]) - - return string(buf) -} - -// SetVersion sets version bits. -func (u *UUID) SetVersion(v byte) { - u[6] = (u[6] & 0x0f) | (v << 4) -} - -// SetVariant sets variant bits as described in RFC 4122. -func (u *UUID) SetVariant() { - u[8] = (u[8] & 0xbf) | 0x80 -} - -// MarshalText implements the encoding.TextMarshaler interface. -// The encoding is the same as returned by String. -func (u UUID) MarshalText() (text []byte, err error) { - text = []byte(u.String()) - return -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// Following formats are supported: -// "6ba7b810-9dad-11d1-80b4-00c04fd430c8", -// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}", -// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" -func (u *UUID) UnmarshalText(text []byte) (err error) { - if len(text) < 32 { - err = fmt.Errorf("uuid: UUID string too short: %s", text) - return - } - - t := text[:] - braced := false - - if bytes.Equal(t[:9], urnPrefix) { - t = t[9:] - } else if t[0] == '{' { - braced = true - t = t[1:] - } - - b := u[:] - - for i, byteGroup := range byteGroups { - if i > 0 && t[0] == '-' { - t = t[1:] - } else if i > 0 && t[0] != '-' { - err = fmt.Errorf("uuid: invalid string format") - return - } - - if i == 2 { - if !bytes.Contains([]byte("012345"), []byte{t[0]}) { - err = fmt.Errorf("uuid: invalid version number: %s", t[0]) - return - } - } - - if len(t) < byteGroup { - err = fmt.Errorf("uuid: UUID string too short: %s", text) - return - } - - if i == 4 && len(t) > byteGroup && - ((braced && t[byteGroup] != '}') || len(t[byteGroup:]) > 1 || !braced) { - err = fmt.Errorf("uuid: UUID string too long: %s", t) - return - } - - _, err = hex.Decode(b[:byteGroup/2], t[:byteGroup]) - - if err != nil { - return - } - - t = t[byteGroup:] - b = b[byteGroup/2:] - } - - return -} - -// MarshalBinary implements the encoding.BinaryMarshaler interface. -func (u UUID) MarshalBinary() (data []byte, err error) { - data = u.Bytes() - return -} - -// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. -// It will return error if the slice isn't 16 bytes long. -func (u *UUID) UnmarshalBinary(data []byte) (err error) { - if len(data) != 16 { - err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data)) - return - } - copy(u[:], data) - - return -} - -// Value implements the driver.Valuer interface. -func (u UUID) Value() (driver.Value, error) { - return u.String(), nil -} - -// Scan implements the sql.Scanner interface. -// A 16-byte slice is handled by UnmarshalBinary, while -// a longer byte slice or a string is handled by UnmarshalText. -func (u *UUID) Scan(src interface{}) error { - switch src := src.(type) { - case []byte: - if len(src) == 16 { - return u.UnmarshalBinary(src) - } - return u.UnmarshalText(src) - - case string: - return u.UnmarshalText([]byte(src)) - } - - return fmt.Errorf("uuid: cannot convert %T to UUID", src) -} - -// Value implements the driver.Valuer interface. -func (u NullUUID) Value() (driver.Value, error) { - if !u.Valid { - return nil, nil - } - // Delegate to UUID Value function - return u.UUID.Value() -} - -// Scan implements the sql.Scanner interface. -func (u *NullUUID) Scan(src interface{}) error { - if src == nil { - u.UUID, u.Valid = Nil, false - return nil - } - - // Delegate to UUID Scan function - u.Valid = true - return u.UUID.Scan(src) -} - -// FromBytes returns UUID converted from raw byte slice input. -// It will return error if the slice isn't 16 bytes long. -func FromBytes(input []byte) (u UUID, err error) { - err = u.UnmarshalBinary(input) - return -} - -// FromBytesOrNil returns UUID converted from raw byte slice input. -// Same behavior as FromBytes, but returns a Nil UUID on error. -func FromBytesOrNil(input []byte) UUID { - uuid, err := FromBytes(input) - if err != nil { - return Nil - } - return uuid -} - -// FromString returns UUID parsed from string input. -// Input is expected in a form accepted by UnmarshalText. -func FromString(input string) (u UUID, err error) { - err = u.UnmarshalText([]byte(input)) - return -} - -// FromStringOrNil returns UUID parsed from string input. -// Same behavior as FromString, but returns a Nil UUID on error. -func FromStringOrNil(input string) UUID { - uuid, err := FromString(input) - if err != nil { - return Nil - } - return uuid -} - -// Returns UUID v1/v2 storage state. -// Returns epoch timestamp, clock sequence, and hardware address. -func getStorage() (uint64, uint16, []byte) { - storageOnce.Do(initStorage) - - storageMutex.Lock() - defer storageMutex.Unlock() - - timeNow := epochFunc() - // Clock changed backwards since last UUID generation. - // Should increase clock sequence. - if timeNow <= lastTime { - clockSequence++ - } - lastTime = timeNow - - return timeNow, clockSequence, hardwareAddr[:] -} - -// NewV1 returns UUID based on current timestamp and MAC address. -func NewV1() UUID { - u := UUID{} - - timeNow, clockSeq, hardwareAddr := getStorage() - - binary.BigEndian.PutUint32(u[0:], uint32(timeNow)) - binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) - binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) - binary.BigEndian.PutUint16(u[8:], clockSeq) - - copy(u[10:], hardwareAddr) - - u.SetVersion(1) - u.SetVariant() - - return u -} - -// NewV2 returns DCE Security UUID based on POSIX UID/GID. -func NewV2(domain byte) UUID { - u := UUID{} - - timeNow, clockSeq, hardwareAddr := getStorage() - - switch domain { - case DomainPerson: - binary.BigEndian.PutUint32(u[0:], posixUID) - case DomainGroup: - binary.BigEndian.PutUint32(u[0:], posixGID) - } - - binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) - binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) - binary.BigEndian.PutUint16(u[8:], clockSeq) - u[9] = domain - - copy(u[10:], hardwareAddr) - - u.SetVersion(2) - u.SetVariant() - - return u -} - -// NewV3 returns UUID based on MD5 hash of namespace UUID and name. -func NewV3(ns UUID, name string) UUID { - u := newFromHash(md5.New(), ns, name) - u.SetVersion(3) - u.SetVariant() - - return u -} - -// NewV4 returns random generated UUID. -func NewV4() UUID { - u := UUID{} - safeRandom(u[:]) - u.SetVersion(4) - u.SetVariant() - - return u -} - -// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name. -func NewV5(ns UUID, name string) UUID { - u := newFromHash(sha1.New(), ns, name) - u.SetVersion(5) - u.SetVariant() - - return u -} - -// Returns UUID based on hashing of namespace UUID and name. -func newFromHash(h hash.Hash, ns UUID, name string) UUID { - u := UUID{} - h.Write(ns[:]) - h.Write([]byte(name)) - copy(u[:], h.Sum(nil)) - - return u -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/ioutil.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/ioutil.go deleted file mode 100644 index c99c987..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/ioutil.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2012, Google Inc. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package ioutil2 - -import ( - "io" - "io/ioutil" - "os" - "path" -) - -// Write file to temp and atomically move when everything else succeeds. -func WriteFileAtomic(filename string, data []byte, perm os.FileMode) error { - dir, name := path.Split(filename) - f, err := ioutil.TempFile(dir, name) - if err != nil { - return err - } - n, err := f.Write(data) - f.Close() - if err == nil && n < len(data) { - err = io.ErrShortWrite - } else { - err = os.Chmod(f.Name(), perm) - } - if err != nil { - os.Remove(f.Name()) - return err - } - return os.Rename(f.Name(), filename) -} - -// Check file exists or not -func FileExists(name string) bool { - _, err := os.Stat(name) - return !os.IsNotExist(err) -} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/sectionwriter.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/sectionwriter.go deleted file mode 100644 index c02ab0d..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/sectionwriter.go +++ /dev/null @@ -1,69 +0,0 @@ -package ioutil2 - -import ( - "errors" - "io" -) - -var ErrExceedLimit = errors.New("write exceed limit") - -func NewSectionWriter(w io.WriterAt, off int64, n int64) *SectionWriter { - return &SectionWriter{w, off, off, off + n} -} - -type SectionWriter struct { - w io.WriterAt - base int64 - off int64 - limit int64 -} - -func (s *SectionWriter) Write(p []byte) (n int, err error) { - if s.off >= s.limit { - return 0, ErrExceedLimit - } - - if max := s.limit - s.off; int64(len(p)) > max { - return 0, ErrExceedLimit - } - - n, err = s.w.WriteAt(p, s.off) - s.off += int64(n) - return -} - -var errWhence = errors.New("Seek: invalid whence") -var errOffset = errors.New("Seek: invalid offset") - -func (s *SectionWriter) Seek(offset int64, whence int) (int64, error) { - switch whence { - default: - return 0, errWhence - case 0: - offset += s.base - case 1: - offset += s.off - case 2: - offset += s.limit - } - if offset < s.base { - return 0, errOffset - } - s.off = offset - return offset - s.base, nil -} - -func (s *SectionWriter) WriteAt(p []byte, off int64) (n int, err error) { - if off < 0 || off >= s.limit-s.base { - return 0, errOffset - } - off += s.base - if max := s.limit - off; int64(len(p)) > max { - return 0, ErrExceedLimit - } - - return s.w.WriteAt(p, off) -} - -// Size returns the size of the section in bytes. -func (s *SectionWriter) Size() int64 { return s.limit - s.base } diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/PATENTS b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/PATENTS deleted file mode 100644 index 7330990..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/PATENTS +++ /dev/null @@ -1,22 +0,0 @@ -Additional IP Rights Grant (Patents) - -"This implementation" means the copyrightable works distributed by -Google as part of the Go project. - -Google hereby grants to You a perpetual, worldwide, non-exclusive, -no-charge, royalty-free, irrevocable (except as stated in this section) -patent license to make, have made, use, offer to sell, sell, import, -transfer and otherwise run, modify and propagate the contents of this -implementation of Go, where such license applies only to those patent -claims, both currently owned or controlled by Google and acquired in -the future, licensable by Google that are necessarily infringed by this -implementation of Go. This grant does not include claims that would be -infringed only as a consequence of further modification of this -implementation. If you or your agent or exclusive licensee institute or -order or agree to the institution of patent litigation against any -entity (including a cross-claim or counterclaim in a lawsuit) alleging -that this implementation of Go or any code incorporated within this -implementation of Go constitutes direct or contributory patent -infringement, or inducement of patent infringement, then any patent -rights granted to you under this License for this implementation of Go -shall terminate as of the date such litigation is filed. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/context/context.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/context/context.go deleted file mode 100644 index 77b64d0..0000000 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/context/context.go +++ /dev/null @@ -1,447 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package context defines the Context type, which carries deadlines, -// cancelation signals, and other request-scoped values across API boundaries -// and between processes. -// -// Incoming requests to a server should create a Context, and outgoing calls to -// servers should accept a Context. The chain of function calls between must -// propagate the Context, optionally replacing it with a modified copy created -// using WithDeadline, WithTimeout, WithCancel, or WithValue. -// -// Programs that use Contexts should follow these rules to keep interfaces -// consistent across packages and enable static analysis tools to check context -// propagation: -// -// Do not store Contexts inside a struct type; instead, pass a Context -// explicitly to each function that needs it. The Context should be the first -// parameter, typically named ctx: -// -// func DoSomething(ctx context.Context, arg Arg) error { -// // ... use ctx ... -// } -// -// Do not pass a nil Context, even if a function permits it. Pass context.TODO -// if you are unsure about which Context to use. -// -// Use context Values only for request-scoped data that transits processes and -// APIs, not for passing optional parameters to functions. -// -// The same Context may be passed to functions running in different goroutines; -// Contexts are safe for simultaneous use by multiple goroutines. -// -// See http://blog.golang.org/context for example code for a server that uses -// Contexts. -package context // import "golang.org/x/net/context" - -import ( - "errors" - "fmt" - "sync" - "time" -) - -// A Context carries a deadline, a cancelation signal, and other values across -// API boundaries. -// -// Context's methods may be called by multiple goroutines simultaneously. -type Context interface { - // Deadline returns the time when work done on behalf of this context - // should be canceled. Deadline returns ok==false when no deadline is - // set. Successive calls to Deadline return the same results. - Deadline() (deadline time.Time, ok bool) - - // Done returns a channel that's closed when work done on behalf of this - // context should be canceled. Done may return nil if this context can - // never be canceled. Successive calls to Done return the same value. - // - // WithCancel arranges for Done to be closed when cancel is called; - // WithDeadline arranges for Done to be closed when the deadline - // expires; WithTimeout arranges for Done to be closed when the timeout - // elapses. - // - // Done is provided for use in select statements: - // - // // Stream generates values with DoSomething and sends them to out - // // until DoSomething returns an error or ctx.Done is closed. - // func Stream(ctx context.Context, out <-chan Value) error { - // for { - // v, err := DoSomething(ctx) - // if err != nil { - // return err - // } - // select { - // case <-ctx.Done(): - // return ctx.Err() - // case out <- v: - // } - // } - // } - // - // See http://blog.golang.org/pipelines for more examples of how to use - // a Done channel for cancelation. - Done() <-chan struct{} - - // Err returns a non-nil error value after Done is closed. Err returns - // Canceled if the context was canceled or DeadlineExceeded if the - // context's deadline passed. No other values for Err are defined. - // After Done is closed, successive calls to Err return the same value. - Err() error - - // Value returns the value associated with this context for key, or nil - // if no value is associated with key. Successive calls to Value with - // the same key returns the same result. - // - // Use context values only for request-scoped data that transits - // processes and API boundaries, not for passing optional parameters to - // functions. - // - // A key identifies a specific value in a Context. Functions that wish - // to store values in Context typically allocate a key in a global - // variable then use that key as the argument to context.WithValue and - // Context.Value. A key can be any type that supports equality; - // packages should define keys as an unexported type to avoid - // collisions. - // - // Packages that define a Context key should provide type-safe accessors - // for the values stores using that key: - // - // // Package user defines a User type that's stored in Contexts. - // package user - // - // import "golang.org/x/net/context" - // - // // User is the type of value stored in the Contexts. - // type User struct {...} - // - // // key is an unexported type for keys defined in this package. - // // This prevents collisions with keys defined in other packages. - // type key int - // - // // userKey is the key for user.User values in Contexts. It is - // // unexported; clients use user.NewContext and user.FromContext - // // instead of using this key directly. - // var userKey key = 0 - // - // // NewContext returns a new Context that carries value u. - // func NewContext(ctx context.Context, u *User) context.Context { - // return context.WithValue(ctx, userKey, u) - // } - // - // // FromContext returns the User value stored in ctx, if any. - // func FromContext(ctx context.Context) (*User, bool) { - // u, ok := ctx.Value(userKey).(*User) - // return u, ok - // } - Value(key interface{}) interface{} -} - -// Canceled is the error returned by Context.Err when the context is canceled. -var Canceled = errors.New("context canceled") - -// DeadlineExceeded is the error returned by Context.Err when the context's -// deadline passes. -var DeadlineExceeded = errors.New("context deadline exceeded") - -// An emptyCtx is never canceled, has no values, and has no deadline. It is not -// struct{}, since vars of this type must have distinct addresses. -type emptyCtx int - -func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { - return -} - -func (*emptyCtx) Done() <-chan struct{} { - return nil -} - -func (*emptyCtx) Err() error { - return nil -} - -func (*emptyCtx) Value(key interface{}) interface{} { - return nil -} - -func (e *emptyCtx) String() string { - switch e { - case background: - return "context.Background" - case todo: - return "context.TODO" - } - return "unknown empty Context" -} - -var ( - background = new(emptyCtx) - todo = new(emptyCtx) -) - -// Background returns a non-nil, empty Context. It is never canceled, has no -// values, and has no deadline. It is typically used by the main function, -// initialization, and tests, and as the top-level Context for incoming -// requests. -func Background() Context { - return background -} - -// TODO returns a non-nil, empty Context. Code should use context.TODO when -// it's unclear which Context to use or it is not yet available (because the -// surrounding function has not yet been extended to accept a Context -// parameter). TODO is recognized by static analysis tools that determine -// whether Contexts are propagated correctly in a program. -func TODO() Context { - return todo -} - -// A CancelFunc tells an operation to abandon its work. -// A CancelFunc does not wait for the work to stop. -// After the first call, subsequent calls to a CancelFunc do nothing. -type CancelFunc func() - -// WithCancel returns a copy of parent with a new Done channel. The returned -// context's Done channel is closed when the returned cancel function is called -// or when the parent context's Done channel is closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { - c := newCancelCtx(parent) - propagateCancel(parent, &c) - return &c, func() { c.cancel(true, Canceled) } -} - -// newCancelCtx returns an initialized cancelCtx. -func newCancelCtx(parent Context) cancelCtx { - return cancelCtx{ - Context: parent, - done: make(chan struct{}), - } -} - -// propagateCancel arranges for child to be canceled when parent is. -func propagateCancel(parent Context, child canceler) { - if parent.Done() == nil { - return // parent is never canceled - } - if p, ok := parentCancelCtx(parent); ok { - p.mu.Lock() - if p.err != nil { - // parent has already been canceled - child.cancel(false, p.err) - } else { - if p.children == nil { - p.children = make(map[canceler]bool) - } - p.children[child] = true - } - p.mu.Unlock() - } else { - go func() { - select { - case <-parent.Done(): - child.cancel(false, parent.Err()) - case <-child.Done(): - } - }() - } -} - -// parentCancelCtx follows a chain of parent references until it finds a -// *cancelCtx. This function understands how each of the concrete types in this -// package represents its parent. -func parentCancelCtx(parent Context) (*cancelCtx, bool) { - for { - switch c := parent.(type) { - case *cancelCtx: - return c, true - case *timerCtx: - return &c.cancelCtx, true - case *valueCtx: - parent = c.Context - default: - return nil, false - } - } -} - -// removeChild removes a context from its parent. -func removeChild(parent Context, child canceler) { - p, ok := parentCancelCtx(parent) - if !ok { - return - } - p.mu.Lock() - if p.children != nil { - delete(p.children, child) - } - p.mu.Unlock() -} - -// A canceler is a context type that can be canceled directly. The -// implementations are *cancelCtx and *timerCtx. -type canceler interface { - cancel(removeFromParent bool, err error) - Done() <-chan struct{} -} - -// A cancelCtx can be canceled. When canceled, it also cancels any children -// that implement canceler. -type cancelCtx struct { - Context - - done chan struct{} // closed by the first cancel call. - - mu sync.Mutex - children map[canceler]bool // set to nil by the first cancel call - err error // set to non-nil by the first cancel call -} - -func (c *cancelCtx) Done() <-chan struct{} { - return c.done -} - -func (c *cancelCtx) Err() error { - c.mu.Lock() - defer c.mu.Unlock() - return c.err -} - -func (c *cancelCtx) String() string { - return fmt.Sprintf("%v.WithCancel", c.Context) -} - -// cancel closes c.done, cancels each of c's children, and, if -// removeFromParent is true, removes c from its parent's children. -func (c *cancelCtx) cancel(removeFromParent bool, err error) { - if err == nil { - panic("context: internal error: missing cancel error") - } - c.mu.Lock() - if c.err != nil { - c.mu.Unlock() - return // already canceled - } - c.err = err - close(c.done) - for child := range c.children { - // NOTE: acquiring the child's lock while holding parent's lock. - child.cancel(false, err) - } - c.children = nil - c.mu.Unlock() - - if removeFromParent { - removeChild(c.Context, c) - } -} - -// WithDeadline returns a copy of the parent context with the deadline adjusted -// to be no later than d. If the parent's deadline is already earlier than d, -// WithDeadline(parent, d) is semantically equivalent to parent. The returned -// context's Done channel is closed when the deadline expires, when the returned -// cancel function is called, or when the parent context's Done channel is -// closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { - if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { - // The current deadline is already sooner than the new one. - return WithCancel(parent) - } - c := &timerCtx{ - cancelCtx: newCancelCtx(parent), - deadline: deadline, - } - propagateCancel(parent, c) - d := deadline.Sub(time.Now()) - if d <= 0 { - c.cancel(true, DeadlineExceeded) // deadline has already passed - return c, func() { c.cancel(true, Canceled) } - } - c.mu.Lock() - defer c.mu.Unlock() - if c.err == nil { - c.timer = time.AfterFunc(d, func() { - c.cancel(true, DeadlineExceeded) - }) - } - return c, func() { c.cancel(true, Canceled) } -} - -// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to -// implement Done and Err. It implements cancel by stopping its timer then -// delegating to cancelCtx.cancel. -type timerCtx struct { - cancelCtx - timer *time.Timer // Under cancelCtx.mu. - - deadline time.Time -} - -func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { - return c.deadline, true -} - -func (c *timerCtx) String() string { - return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) -} - -func (c *timerCtx) cancel(removeFromParent bool, err error) { - c.cancelCtx.cancel(false, err) - if removeFromParent { - // Remove this timerCtx from its parent cancelCtx's children. - removeChild(c.cancelCtx.Context, c) - } - c.mu.Lock() - if c.timer != nil { - c.timer.Stop() - c.timer = nil - } - c.mu.Unlock() -} - -// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete: -// -// func slowOperationWithTimeout(ctx context.Context) (Result, error) { -// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) -// defer cancel() // releases resources if slowOperation completes before timeout elapses -// return slowOperation(ctx) -// } -func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { - return WithDeadline(parent, time.Now().Add(timeout)) -} - -// WithValue returns a copy of parent in which the value associated with key is -// val. -// -// Use context Values only for request-scoped data that transits processes and -// APIs, not for passing optional parameters to functions. -func WithValue(parent Context, key interface{}, val interface{}) Context { - return &valueCtx{parent, key, val} -} - -// A valueCtx carries a key-value pair. It implements Value for that key and -// delegates all other calls to the embedded Context. -type valueCtx struct { - Context - key, val interface{} -} - -func (c *valueCtx) String() string { - return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) -} - -func (c *valueCtx) Value(key interface{}) interface{} { - if c.key == key { - return c.val - } - return c.Context.Value(key) -} diff --git a/vendor/github.com/siddontang/go-mysql/canal/canal.go b/vendor/github.com/siddontang/go-mysql/canal/canal.go index cdeaf98..64d9aec 100644 --- a/vendor/github.com/siddontang/go-mysql/canal/canal.go +++ b/vendor/github.com/siddontang/go-mysql/canal/canal.go @@ -1,26 +1,25 @@ package canal import ( + "context" "fmt" "io/ioutil" "os" - "path" + "regexp" "strconv" "strings" "sync" + "time" "github.com/juju/errors" - "github.com/ngaut/log" + "github.com/siddontang/go-log/log" "github.com/siddontang/go-mysql/client" "github.com/siddontang/go-mysql/dump" "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/replication" "github.com/siddontang/go-mysql/schema" - "github.com/siddontang/go/sync2" ) -var errCanalClosed = errors.New("canal was closed") - // Canal can sync your MySQL data into everywhere, like Elasticsearch, Redis, etc... // MySQL must open row format for binlog type Canal struct { @@ -30,48 +29,49 @@ type Canal struct { master *masterInfo dumper *dump.Dumper + dumped bool dumpDoneCh chan struct{} syncer *replication.BinlogSyncer - rsLock sync.Mutex - rsHandlers []RowsEventHandler + eventHandler EventHandler connLock sync.Mutex conn *client.Conn - wg sync.WaitGroup + tableLock sync.RWMutex + tables map[string]*schema.Table + errorTablesGetTime map[string]time.Time - tableLock sync.Mutex - tables map[string]*schema.Table + tableMatchCache map[string]bool + includeTableRegex []*regexp.Regexp + excludeTableRegex []*regexp.Regexp - quit chan struct{} - closed sync2.AtomicBool + ctx context.Context + cancel context.CancelFunc } +// canal will retry fetching unknown table's meta after UnknownTableRetryPeriod +var UnknownTableRetryPeriod = time.Second * time.Duration(10) +var ErrExcludedTable = errors.New("excluded table meta") + func NewCanal(cfg *Config) (*Canal, error) { c := new(Canal) c.cfg = cfg - c.closed.Set(false) - c.quit = make(chan struct{}) - os.MkdirAll(cfg.DataDir, 0755) + c.ctx, c.cancel = context.WithCancel(context.Background()) c.dumpDoneCh = make(chan struct{}) - c.rsHandlers = make([]RowsEventHandler, 0, 4) + c.eventHandler = &DummyEventHandler{} + c.tables = make(map[string]*schema.Table) + if c.cfg.DiscardNoMetaRowEvent { + c.errorTablesGetTime = make(map[string]time.Time) + } + c.master = &masterInfo{} var err error - if c.master, err = loadMasterInfo(c.masterInfoPath()); err != nil { - return nil, errors.Trace(err) - } else if len(c.master.Addr) != 0 && c.master.Addr != c.cfg.Addr { - log.Infof("MySQL addr %s in old master.info, but new %s, reset", c.master.Addr, c.cfg.Addr) - // may use another MySQL, reset - c.master = &masterInfo{} - } - c.master.Addr = c.cfg.Addr - - if err := c.prepareDumper(); err != nil { + if err = c.prepareDumper(); err != nil { return nil, errors.Trace(err) } @@ -83,6 +83,33 @@ func NewCanal(cfg *Config) (*Canal, error) { return nil, errors.Trace(err) } + // init table filter + if n := len(c.cfg.IncludeTableRegex); n > 0 { + c.includeTableRegex = make([]*regexp.Regexp, n) + for i, val := range c.cfg.IncludeTableRegex { + reg, err := regexp.Compile(val) + if err != nil { + return nil, errors.Trace(err) + } + c.includeTableRegex[i] = reg + } + } + + if n := len(c.cfg.ExcludeTableRegex); n > 0 { + c.excludeTableRegex = make([]*regexp.Regexp, n) + for i, val := range c.cfg.ExcludeTableRegex { + reg, err := regexp.Compile(val) + if err != nil { + return nil, errors.Trace(err) + } + c.excludeTableRegex[i] = reg + } + } + + if c.includeTableRegex != nil || c.excludeTableRegex != nil { + c.tableMatchCache = make(map[string]bool) + } + return c, nil } @@ -114,6 +141,15 @@ func (c *Canal) prepareDumper() error { c.dumper.AddTables(tableDB, tables...) } + charset := c.cfg.Charset + c.dumper.SetCharset(charset) + + c.dumper.SetWhere(c.cfg.Dump.Where) + c.dumper.SkipMasterData(c.cfg.Dump.SkipMasterData) + c.dumper.SetMaxAllowedPacket(c.cfg.Dump.MaxAllowedPacketMB) + // Use hex blob for mysqldump + c.dumper.SetHexBlob(true) + for _, ignoreTable := range c.cfg.Dump.IgnoreTables { if seps := strings.Split(ignoreTable, ","); len(seps) == 2 { c.dumper.AddIgnoreTables(seps[0], seps[1]) @@ -129,92 +165,208 @@ func (c *Canal) prepareDumper() error { return nil } -func (c *Canal) Start() error { - c.wg.Add(1) - go c.run() +// Run will first try to dump all data from MySQL master `mysqldump`, +// then sync from the binlog position in the dump data. +// It will run forever until meeting an error or Canal closed. +func (c *Canal) Run() error { + return c.run() +} - return nil +// RunFrom will sync from the binlog position directly, ignore mysqldump. +func (c *Canal) RunFrom(pos mysql.Position) error { + c.master.Update(pos) + + return c.Run() +} + +func (c *Canal) StartFromGTID(set mysql.GTIDSet) error { + c.master.UpdateGTIDSet(set) + + return c.Run() +} + +// Dump all data from MySQL master `mysqldump`, ignore sync binlog. +func (c *Canal) Dump() error { + if c.dumped { + return errors.New("the method Dump can't be called twice") + } + c.dumped = true + defer close(c.dumpDoneCh) + return c.dump() } func (c *Canal) run() error { - defer c.wg.Done() + defer func() { + c.cancel() + }() - if err := c.tryDump(); err != nil { - log.Errorf("canal dump mysql err: %v", err) - return errors.Trace(err) + c.master.UpdateTimestamp(uint32(time.Now().Unix())) + + if !c.dumped { + c.dumped = true + + err := c.tryDump() + close(c.dumpDoneCh) + + if err != nil { + log.Errorf("canal dump mysql err: %v", err) + return errors.Trace(err) + } } - close(c.dumpDoneCh) - - if err := c.startSyncBinlog(); err != nil { - if !c.isClosed() { - log.Errorf("canal start sync binlog err: %v", err) - } + if err := c.runSyncBinlog(); err != nil { + log.Errorf("canal start sync binlog err: %v", err) return errors.Trace(err) } return nil } -func (c *Canal) isClosed() bool { - return c.closed.Get() -} - func (c *Canal) Close() { - log.Infof("close canal") + log.Infof("closing canal") c.m.Lock() defer c.m.Unlock() - if c.isClosed() { - return - } - - c.closed.Set(true) - - close(c.quit) - + c.cancel() c.connLock.Lock() c.conn.Close() c.conn = nil c.connLock.Unlock() + c.syncer.Close() - if c.syncer != nil { - c.syncer.Close() - c.syncer = nil - } - - c.master.Close() - - c.wg.Wait() + c.eventHandler.OnPosSynced(c.master.Position(), true) } func (c *Canal) WaitDumpDone() <-chan struct{} { return c.dumpDoneCh } +func (c *Canal) Ctx() context.Context { + return c.ctx +} + +func (c *Canal) checkTableMatch(key string) bool { + // no filter, return true + if c.tableMatchCache == nil { + return true + } + + c.tableLock.RLock() + rst, ok := c.tableMatchCache[key] + c.tableLock.RUnlock() + if ok { + // cache hit + return rst + } + matchFlag := false + // check include + if c.includeTableRegex != nil { + for _, reg := range c.includeTableRegex { + if reg.MatchString(key) { + matchFlag = true + break + } + } + } + // check exclude + if matchFlag && c.excludeTableRegex != nil { + for _, reg := range c.excludeTableRegex { + if reg.MatchString(key) { + matchFlag = false + break + } + } + } + c.tableLock.Lock() + c.tableMatchCache[key] = matchFlag + c.tableLock.Unlock() + return matchFlag +} + func (c *Canal) GetTable(db string, table string) (*schema.Table, error) { key := fmt.Sprintf("%s.%s", db, table) - c.tableLock.Lock() + // if table is excluded, return error and skip parsing event or dump + if !c.checkTableMatch(key) { + return nil, ErrExcludedTable + } + c.tableLock.RLock() t, ok := c.tables[key] - c.tableLock.Unlock() + c.tableLock.RUnlock() if ok { return t, nil } + if c.cfg.DiscardNoMetaRowEvent { + c.tableLock.RLock() + lastTime, ok := c.errorTablesGetTime[key] + c.tableLock.RUnlock() + if ok && time.Now().Sub(lastTime) < UnknownTableRetryPeriod { + return nil, schema.ErrMissingTableMeta + } + } + t, err := schema.NewTable(c, db, table) if err != nil { - return nil, errors.Trace(err) + // check table not exists + if ok, err1 := schema.IsTableExist(c, db, table); err1 == nil && !ok { + return nil, schema.ErrTableNotExist + } + // work around : RDS HAHeartBeat + // ref : https://github.com/alibaba/canal/blob/master/parse/src/main/java/com/alibaba/otter/canal/parse/inbound/mysql/dbsync/LogEventConvert.java#L385 + // issue : https://github.com/alibaba/canal/issues/222 + // This is a common error in RDS that canal can't get HAHealthCheckSchema's meta, so we mock a table meta. + // If canal just skip and log error, as RDS HA heartbeat interval is very short, so too many HAHeartBeat errors will be logged. + if key == schema.HAHealthCheckSchema { + // mock ha_health_check meta + ta := &schema.Table{ + Schema: db, + Name: table, + Columns: make([]schema.TableColumn, 0, 2), + Indexes: make([]*schema.Index, 0), + } + ta.AddColumn("id", "bigint(20)", "", "") + ta.AddColumn("type", "char(1)", "", "") + c.tableLock.Lock() + c.tables[key] = ta + c.tableLock.Unlock() + return ta, nil + } + // if DiscardNoMetaRowEvent is true, we just log this error + if c.cfg.DiscardNoMetaRowEvent { + c.tableLock.Lock() + c.errorTablesGetTime[key] = time.Now() + c.tableLock.Unlock() + // log error and return ErrMissingTableMeta + log.Errorf("canal get table meta err: %v", errors.Trace(err)) + return nil, schema.ErrMissingTableMeta + } + return nil, err } c.tableLock.Lock() c.tables[key] = t + if c.cfg.DiscardNoMetaRowEvent { + // if get table info success, delete this key from errorTablesGetTime + delete(c.errorTablesGetTime, key) + } c.tableLock.Unlock() return t, nil } +// ClearTableCache clear table cache +func (c *Canal) ClearTableCache(db []byte, table []byte) { + key := fmt.Sprintf("%s.%s", db, table) + c.tableLock.Lock() + delete(c.tables, key) + if c.cfg.DiscardNoMetaRowEvent { + delete(c.errorTablesGetTime, key) + } + c.tableLock.Unlock() +} + // Check MySQL binlog row image, must be in FULL, MINIMAL, NOBLOB func (c *Canal) CheckBinlogRowImage(image string) error { // need to check MySQL binlog row image? full, minimal or noblob? @@ -246,34 +398,41 @@ func (c *Canal) checkBinlogRowFormat() error { } func (c *Canal) prepareSyncer() error { - seps := strings.Split(c.cfg.Addr, ":") - if len(seps) != 2 { - return errors.Errorf("invalid mysql addr format %s, must host:port", c.cfg.Addr) - } - - port, err := strconv.ParseUint(seps[1], 10, 16) - if err != nil { - return errors.Trace(err) - } - cfg := replication.BinlogSyncerConfig{ - ServerID: c.cfg.ServerID, - Flavor: c.cfg.Flavor, - Host: seps[0], - Port: uint16(port), - User: c.cfg.User, - Password: c.cfg.Password, + ServerID: c.cfg.ServerID, + Flavor: c.cfg.Flavor, + User: c.cfg.User, + Password: c.cfg.Password, + Charset: c.cfg.Charset, + HeartbeatPeriod: c.cfg.HeartbeatPeriod, + ReadTimeout: c.cfg.ReadTimeout, + UseDecimal: c.cfg.UseDecimal, + ParseTime: c.cfg.ParseTime, + SemiSyncEnabled: c.cfg.SemiSyncEnabled, } - c.syncer = replication.NewBinlogSyncer(&cfg) + if strings.Contains(c.cfg.Addr, "/") { + cfg.Host = c.cfg.Addr + } else { + seps := strings.Split(c.cfg.Addr, ":") + if len(seps) != 2 { + return errors.Errorf("invalid mysql addr format %s, must host:port", c.cfg.Addr) + } + + port, err := strconv.ParseUint(seps[1], 10, 16) + if err != nil { + return errors.Trace(err) + } + + cfg.Host = seps[0] + cfg.Port = uint16(port) + } + + c.syncer = replication.NewBinlogSyncer(cfg) return nil } -func (c *Canal) masterInfoPath() string { - return path.Join(c.cfg.DataDir, "master.info") -} - // Execute a SQL func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err error) { c.connLock.Lock() @@ -303,5 +462,13 @@ func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err } func (c *Canal) SyncedPosition() mysql.Position { - return c.master.Pos() + return c.master.Position() +} + +func (c *Canal) SyncedTimestamp() uint32 { + return c.master.timestamp +} + +func (c *Canal) SyncedGTIDSet() mysql.GTIDSet { + return c.master.GTIDSet() } diff --git a/vendor/github.com/siddontang/go-mysql/canal/canal_test.go b/vendor/github.com/siddontang/go-mysql/canal/canal_test.go old mode 100644 new mode 100755 index b83c79e..bd16bd2 --- a/vendor/github.com/siddontang/go-mysql/canal/canal_test.go +++ b/vendor/github.com/siddontang/go-mysql/canal/canal_test.go @@ -1,13 +1,15 @@ package canal import ( + "bytes" "flag" "fmt" - "os" "testing" + "time" - "github.com/ngaut/log" + "github.com/juju/errors" . "github.com/pingcap/check" + "github.com/siddontang/go-log/log" "github.com/siddontang/go-mysql/mysql" ) @@ -27,19 +29,28 @@ func (s *canalTestSuite) SetUpSuite(c *C) { cfg := NewDefaultConfig() cfg.Addr = fmt.Sprintf("%s:3306", *testHost) cfg.User = "root" + cfg.HeartbeatPeriod = 200 * time.Millisecond + cfg.ReadTimeout = 300 * time.Millisecond cfg.Dump.ExecutionPath = "mysqldump" cfg.Dump.TableDB = "test" cfg.Dump.Tables = []string{"canal_test"} + cfg.Dump.Where = "id>0" - os.RemoveAll(cfg.DataDir) + // include & exclude config + cfg.IncludeTableRegex = make([]string, 1) + cfg.IncludeTableRegex[0] = ".*\\.canal_test" + cfg.ExcludeTableRegex = make([]string, 2) + cfg.ExcludeTableRegex[0] = "mysql\\..*" + cfg.ExcludeTableRegex[1] = ".*\\..*_inner" var err error s.c, err = NewCanal(cfg) c.Assert(err, IsNil) - + s.execute(c, "DROP TABLE IF EXISTS test.canal_test") sql := ` CREATE TABLE IF NOT EXISTS test.canal_test ( - id int AUTO_INCREMENT, + id int AUTO_INCREMENT, + content blob DEFAULT NULL, name varchar(100), PRIMARY KEY(id) )ENGINE=innodb; @@ -48,16 +59,22 @@ func (s *canalTestSuite) SetUpSuite(c *C) { s.execute(c, sql) s.execute(c, "DELETE FROM test.canal_test") - s.execute(c, "INSERT INTO test.canal_test (name) VALUES (?), (?), (?)", "a", "b", "c") + s.execute(c, "INSERT INTO test.canal_test (content, name) VALUES (?, ?), (?, ?), (?, ?)", "1", "a", `\0\ndsfasdf`, "b", "", "c") s.execute(c, "SET GLOBAL binlog_format = 'ROW'") - s.c.RegRowsEventHandler(&testRowsEventHandler{}) - err = s.c.Start() - c.Assert(err, IsNil) + s.c.SetEventHandler(&testEventHandler{c: c}) + go func() { + err = s.c.Run() + c.Assert(err, IsNil) + }() } func (s *canalTestSuite) TearDownSuite(c *C) { + // To test the heartbeat and read timeout,so need to sleep 1 seconds without data transmission + c.Logf("Start testing the heartbeat and read timeout") + time.Sleep(time.Second) + if s.c != nil { s.c.Close() s.c = nil @@ -70,16 +87,19 @@ func (s *canalTestSuite) execute(c *C, query string, args ...interface{}) *mysql return r } -type testRowsEventHandler struct { +type testEventHandler struct { + DummyEventHandler + + c *C } -func (h *testRowsEventHandler) Do(e *RowsEvent) error { - log.Infof("%s %v\n", e.Action, e.Rows) +func (h *testEventHandler) OnRow(e *RowsEvent) error { + log.Infof("OnRow %s %v\n", e.Action, e.Rows) return nil } -func (h *testRowsEventHandler) String() string { - return "testRowsEventHandler" +func (h *testEventHandler) String() string { + return "testEventHandler" } func (s *canalTestSuite) TestCanal(c *C) { @@ -88,7 +108,126 @@ func (s *canalTestSuite) TestCanal(c *C) { for i := 1; i < 10; i++ { s.execute(c, "INSERT INTO test.canal_test (name) VALUES (?)", fmt.Sprintf("%d", i)) } + s.execute(c, "ALTER TABLE test.canal_test ADD `age` INT(5) NOT NULL AFTER `name`") + s.execute(c, "INSERT INTO test.canal_test (name,age) VALUES (?,?)", "d", "18") - err := s.c.CatchMasterPos(100) + err := s.c.CatchMasterPos(10 * time.Second) c.Assert(err, IsNil) } + +func (s *canalTestSuite) TestCanalFilter(c *C) { + // included + sch, err := s.c.GetTable("test", "canal_test") + c.Assert(err, IsNil) + c.Assert(sch, NotNil) + _, err = s.c.GetTable("not_exist_db", "canal_test") + c.Assert(errors.Trace(err), Not(Equals), ErrExcludedTable) + // excluded + sch, err = s.c.GetTable("test", "canal_test_inner") + c.Assert(errors.Cause(err), Equals, ErrExcludedTable) + c.Assert(sch, IsNil) + sch, err = s.c.GetTable("mysql", "canal_test") + c.Assert(errors.Cause(err), Equals, ErrExcludedTable) + c.Assert(sch, IsNil) + sch, err = s.c.GetTable("not_exist_db", "not_canal_test") + c.Assert(errors.Cause(err), Equals, ErrExcludedTable) + c.Assert(sch, IsNil) +} + +func TestCreateTableExp(t *testing.T) { + cases := []string{ + "CREATE TABLE `mydb.mytable` (`id` int(10)) ENGINE=InnoDB", + "CREATE TABLE `mytable` (`id` int(10)) ENGINE=InnoDB", + "CREATE TABLE IF NOT EXISTS `mytable` (`id` int(10)) ENGINE=InnoDB", + "CREATE TABLE IF NOT EXISTS mytable (`id` int(10)) ENGINE=InnoDB", + } + table := []byte("mytable") + db := []byte("mydb") + for _, s := range cases { + m := expCreateTable.FindSubmatch([]byte(s)) + mLen := len(m) + if m == nil || !bytes.Equal(m[mLen-1], table) || (len(m[mLen-2]) > 0 && !bytes.Equal(m[mLen-2], db)) { + t.Fatalf("TestCreateTableExp: case %s failed\n", s) + } + } +} + +func TestAlterTableExp(t *testing.T) { + cases := []string{ + "ALTER TABLE `mydb`.`mytable` ADD `field2` DATE NULL AFTER `field1`;", + "ALTER TABLE `mytable` ADD `field2` DATE NULL AFTER `field1`;", + "ALTER TABLE mydb.mytable ADD `field2` DATE NULL AFTER `field1`;", + "ALTER TABLE mytable ADD `field2` DATE NULL AFTER `field1`;", + "ALTER TABLE mydb.mytable ADD field2 DATE NULL AFTER `field1`;", + } + + table := []byte("mytable") + db := []byte("mydb") + for _, s := range cases { + m := expAlterTable.FindSubmatch([]byte(s)) + mLen := len(m) + if m == nil || !bytes.Equal(m[mLen-1], table) || (len(m[mLen-2]) > 0 && !bytes.Equal(m[mLen-2], db)) { + t.Fatalf("TestAlterTableExp: case %s failed\n", s) + } + } +} + +func TestRenameTableExp(t *testing.T) { + cases := []string{ + "rename table `mydb`.`mytable` to `mydb`.`mytable1`", + "rename table `mytable` to `mytable1`", + "rename table mydb.mytable to mydb.mytable1", + "rename table mytable to mytable1", + + "rename table `mydb`.`mytable` to `mydb`.`mytable2`, `mydb`.`mytable3` to `mydb`.`mytable1`", + "rename table `mytable` to `mytable2`, `mytable3` to `mytable1`", + "rename table mydb.mytable to mydb.mytable2, mydb.mytable3 to mydb.mytable1", + "rename table mytable to mytable2, mytable3 to mytable1", + } + table := []byte("mytable") + db := []byte("mydb") + for _, s := range cases { + m := expRenameTable.FindSubmatch([]byte(s)) + mLen := len(m) + if m == nil || !bytes.Equal(m[mLen-1], table) || (len(m[mLen-2]) > 0 && !bytes.Equal(m[mLen-2], db)) { + t.Fatalf("TestRenameTableExp: case %s failed\n", s) + } + } +} + +func TestDropTableExp(t *testing.T) { + cases := []string{ + "drop table test1", + "DROP TABLE test1", + "DROP TABLE test1", + "DROP table IF EXISTS test.test1", + "drop table `test1`", + "DROP TABLE `test1`", + "DROP table IF EXISTS `test`.`test1`", + "DROP TABLE `test1` /* generated by server */", + "DROP table if exists test1", + "DROP table if exists `test1`", + "DROP table if exists test.test1", + "DROP table if exists `test`.test1", + "DROP table if exists `test`.`test1`", + "DROP table if exists test.`test1`", + "DROP table if exists test.`test1`", + } + + table := []byte("test1") + for _, s := range cases { + m := expDropTable.FindSubmatch([]byte(s)) + mLen := len(m) + if m == nil { + t.Fatalf("TestDropTableExp: case %s failed\n", s) + return + } + if mLen < 4 { + t.Fatalf("TestDropTableExp: case %s failed\n", s) + return + } + if !bytes.Equal(m[mLen-1], table) { + t.Fatalf("TestDropTableExp: case %s failed\n", s) + } + } +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/config.go b/vendor/github.com/siddontang/go-mysql/canal/config.go index f9bb262..d10513c 100644 --- a/vendor/github.com/siddontang/go-mysql/canal/config.go +++ b/vendor/github.com/siddontang/go-mysql/canal/config.go @@ -7,6 +7,7 @@ import ( "github.com/BurntSushi/toml" "github.com/juju/errors" + "github.com/siddontang/go-mysql/mysql" ) type DumpConfig struct { @@ -23,8 +24,18 @@ type DumpConfig struct { // Ignore table format is db.table IgnoreTables []string `toml:"ignore_tables"` + // Dump only selected records. Quotes are mandatory + Where string `toml:"where"` + // If true, discard error msg, else, output to stderr DiscardErr bool `toml:"discard_err"` + + // Set true to skip --master-data if we have no privilege to do + // 'FLUSH TABLES WITH READ LOCK' + SkipMasterData bool `toml:"skip_master_data"` + + // Set to change the default max_allowed_packet size + MaxAllowedPacketMB int `toml:"max_allowed_packet_mb"` } type Config struct { @@ -32,11 +43,30 @@ type Config struct { User string `toml:"user"` Password string `toml:"password"` - ServerID uint32 `toml:"server_id"` - Flavor string `toml:"flavor"` - DataDir string `toml:"data_dir"` + Charset string `toml:"charset"` + ServerID uint32 `toml:"server_id"` + Flavor string `toml:"flavor"` + HeartbeatPeriod time.Duration `toml:"heartbeat_period"` + ReadTimeout time.Duration `toml:"read_timeout"` + + // IncludeTableRegex or ExcludeTableRegex should contain database name + // Only a table which matches IncludeTableRegex and dismatches ExcludeTableRegex will be processed + // eg, IncludeTableRegex : [".*\\.canal"], ExcludeTableRegex : ["mysql\\..*"] + // this will include all database's 'canal' table, except database 'mysql' + // Default IncludeTableRegex and ExcludeTableRegex are empty, this will include all tables + IncludeTableRegex []string `toml:"include_table_regex"` + ExcludeTableRegex []string `toml:"exclude_table_regex"` + + // discard row event without table meta + DiscardNoMetaRowEvent bool `toml:"discard_no_meta_row_event"` Dump DumpConfig `toml:"dump"` + + UseDecimal bool `toml:"use_decimal"` + ParseTime bool `toml:"parse_time"` + + // SemiSyncEnabled enables semi-sync or not. + SemiSyncEnabled bool `toml:"semi_sync_enabled"` } func NewConfigWithFile(name string) (*Config, error) { @@ -66,14 +96,14 @@ func NewDefaultConfig() *Config { c.User = "root" c.Password = "" - rand.Seed(time.Now().Unix()) - c.ServerID = uint32(rand.Intn(1000)) + 1001 + c.Charset = mysql.DEFAULT_CHARSET + c.ServerID = uint32(rand.New(rand.NewSource(time.Now().Unix())).Intn(1000)) + 1001 c.Flavor = "mysql" - c.DataDir = "./var" c.Dump.ExecutionPath = "mysqldump" c.Dump.DiscardErr = true + c.Dump.SkipMasterData = false return c } diff --git a/vendor/github.com/siddontang/go-mysql/canal/dump.go b/vendor/github.com/siddontang/go-mysql/canal/dump.go index 4adec2a..8dcac2b 100644 --- a/vendor/github.com/siddontang/go-mysql/canal/dump.go +++ b/vendor/github.com/siddontang/go-mysql/canal/dump.go @@ -1,12 +1,16 @@ package canal import ( + "encoding/hex" + "fmt" "strconv" + "strings" "time" "github.com/juju/errors" - "github.com/ngaut/log" - "github.com/siddontang/go-mysql/dump" + "github.com/shopspring/decimal" + "github.com/siddontang/go-log/log" + "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/schema" ) @@ -14,6 +18,7 @@ type dumpParseHandler struct { c *Canal name string pos uint64 + gset mysql.GTIDSet } func (h *dumpParseHandler) BinLog(name string, pos uint64) error { @@ -23,12 +28,18 @@ func (h *dumpParseHandler) BinLog(name string, pos uint64) error { } func (h *dumpParseHandler) Data(db string, table string, values []string) error { - if h.c.isClosed() { - return errCanalClosed + if err := h.c.ctx.Err(); err != nil { + return err } tableInfo, err := h.c.GetTable(db, table) if err != nil { + e := errors.Cause(err) + if e == ErrExcludedTable || + e == schema.ErrTableNotExist || + e == schema.ErrMissingTableMeta { + return nil + } log.Errorf("get %s.%s information err: %v", db, table, err) return errors.Trace(err) } @@ -38,32 +49,51 @@ func (h *dumpParseHandler) Data(db string, table string, values []string) error for i, v := range values { if v == "NULL" { vs[i] = nil + } else if v == "_binary ''" { + vs[i] = []byte{} } else if v[0] != '\'' { if tableInfo.Columns[i].Type == schema.TYPE_NUMBER { n, err := strconv.ParseInt(v, 10, 64) if err != nil { - log.Errorf("parse row %v at %d error %v, skip", values, i, err) - return dump.ErrSkip + return fmt.Errorf("parse row %v at %d error %v, int expected", values, i, err) } vs[i] = n } else if tableInfo.Columns[i].Type == schema.TYPE_FLOAT { f, err := strconv.ParseFloat(v, 64) if err != nil { - log.Errorf("parse row %v at %d error %v, skip", values, i, err) - return dump.ErrSkip + return fmt.Errorf("parse row %v at %d error %v, float expected", values, i, err) } vs[i] = f + } else if tableInfo.Columns[i].Type == schema.TYPE_DECIMAL { + if h.c.cfg.UseDecimal { + d, err := decimal.NewFromString(v) + if err != nil { + return fmt.Errorf("parse row %v at %d error %v, decimal expected", values, i, err) + } + vs[i] = d + } else { + f, err := strconv.ParseFloat(v, 64) + if err != nil { + return fmt.Errorf("parse row %v at %d error %v, float expected", values, i, err) + } + vs[i] = f + } + } else if strings.HasPrefix(v, "0x") { + buf, err := hex.DecodeString(v[2:]) + if err != nil { + return fmt.Errorf("parse row %v at %d error %v, hex literal expected", values, i, err) + } + vs[i] = string(buf) } else { - log.Errorf("parse row %v error, invalid type at %d, skip", values, i) - return dump.ErrSkip + return fmt.Errorf("parse row %v error, invalid type at %d", values, i) } } else { vs[i] = v[1 : len(v)-1] } } - events := newRowsEvent(tableInfo, InsertAction, [][]interface{}{vs}) - return h.c.travelRowsEventHandler(events) + events := newRowsEvent(tableInfo, InsertAction, [][]interface{}{vs}, nil) + return h.c.eventHandler.OnRow(events) } func (c *Canal) AddDumpDatabases(dbs ...string) { @@ -90,10 +120,64 @@ func (c *Canal) AddDumpIgnoreTables(db string, tables ...string) { c.dumper.AddIgnoreTables(db, tables...) } +func (c *Canal) dump() error { + if c.dumper == nil { + return errors.New("mysqldump does not exist") + } + + c.master.UpdateTimestamp(uint32(time.Now().Unix())) + + h := &dumpParseHandler{c: c} + // If users call StartFromGTID with empty position to start dumping with gtid, + // we record the current gtid position before dump starts. + // + // See tryDump() to see when dump is skipped. + if c.master.GTIDSet() != nil { + gset, err := c.GetMasterGTIDSet() + if err != nil { + return errors.Trace(err) + } + h.gset = gset + } + + if c.cfg.Dump.SkipMasterData { + pos, err := c.GetMasterPos() + if err != nil { + return errors.Trace(err) + } + log.Infof("skip master data, get current binlog position %v", pos) + h.name = pos.Name + h.pos = uint64(pos.Pos) + } + + start := time.Now() + log.Info("try dump MySQL and parse") + if err := c.dumper.DumpAndParse(h); err != nil { + return errors.Trace(err) + } + + pos := mysql.Position{Name: h.name, Pos: uint32(h.pos)} + c.master.Update(pos) + if err := c.eventHandler.OnPosSynced(pos, true); err != nil { + return errors.Trace(err) + } + var startPos fmt.Stringer = pos + if h.gset != nil { + c.master.UpdateGTIDSet(h.gset) + startPos = h.gset + } + log.Infof("dump MySQL and parse OK, use %0.2f seconds, start binlog replication at %s", + time.Now().Sub(start).Seconds(), startPos) + return nil +} + func (c *Canal) tryDump() error { - if len(c.master.Name) > 0 && c.master.Position > 0 { + pos := c.master.Position() + gset := c.master.GTIDSet() + if (len(pos.Name) > 0 && pos.Pos > 0) || + (gset != nil && gset.String() != "") { // we will sync with binlog name and position - log.Infof("skip dump, use last binlog replication pos (%s, %d)", c.master.Name, c.master.Position) + log.Infof("skip dump, use last binlog replication pos %s or GTID set %s", pos, gset) return nil } @@ -102,19 +186,5 @@ func (c *Canal) tryDump() error { return nil } - h := &dumpParseHandler{c: c} - - start := time.Now() - log.Info("try dump MySQL and parse") - if err := c.dumper.DumpAndParse(h); err != nil { - return errors.Trace(err) - } - - log.Infof("dump MySQL and parse OK, use %0.2f seconds, start binlog replication at (%s, %d)", - time.Now().Sub(start).Seconds(), h.name, h.pos) - - c.master.Update(h.name, uint32(h.pos)) - c.master.Save(true) - - return nil + return c.dump() } diff --git a/vendor/github.com/siddontang/go-mysql/canal/handler.go b/vendor/github.com/siddontang/go-mysql/canal/handler.go index e361181..4e47cb9 100644 --- a/vendor/github.com/siddontang/go-mysql/canal/handler.go +++ b/vendor/github.com/siddontang/go-mysql/canal/handler.go @@ -1,41 +1,41 @@ package canal import ( - "github.com/juju/errors" - "github.com/ngaut/log" "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/replication" ) -var ( - ErrHandleInterrupted = errors.New("do handler error, interrupted") -) - -type RowsEventHandler interface { - // Handle RowsEvent, if return ErrHandleInterrupted, canal will - // stop the sync - Do(e *RowsEvent) error +type EventHandler interface { + OnRotate(roateEvent *replication.RotateEvent) error + // OnTableChanged is called when the table is created, altered, renamed or dropped. + // You need to clear the associated data like cache with the table. + // It will be called before OnDDL. + OnTableChanged(schema string, table string) error + OnDDL(nextPos mysql.Position, queryEvent *replication.QueryEvent) error + OnRow(e *RowsEvent) error + OnXID(nextPos mysql.Position) error + OnGTID(gtid mysql.GTIDSet) error + // OnPosSynced Use your own way to sync position. When force is true, sync position immediately. + OnPosSynced(pos mysql.Position, force bool) error String() string } -func (c *Canal) RegRowsEventHandler(h RowsEventHandler) { - c.rsLock.Lock() - c.rsHandlers = append(c.rsHandlers, h) - c.rsLock.Unlock() +type DummyEventHandler struct { } -func (c *Canal) travelRowsEventHandler(e *RowsEvent) error { - c.rsLock.Lock() - defer c.rsLock.Unlock() - - var err error - for _, h := range c.rsHandlers { - if err = h.Do(e); err != nil && !mysql.ErrorEqual(err, ErrHandleInterrupted) { - log.Errorf("handle %v err: %v", h, err) - } else if mysql.ErrorEqual(err, ErrHandleInterrupted) { - log.Errorf("handle %v err, interrupted", h) - return ErrHandleInterrupted - } - - } +func (h *DummyEventHandler) OnRotate(*replication.RotateEvent) error { return nil } +func (h *DummyEventHandler) OnTableChanged(schema string, table string) error { return nil } +func (h *DummyEventHandler) OnDDL(nextPos mysql.Position, queryEvent *replication.QueryEvent) error { return nil } +func (h *DummyEventHandler) OnRow(*RowsEvent) error { return nil } +func (h *DummyEventHandler) OnXID(mysql.Position) error { return nil } +func (h *DummyEventHandler) OnGTID(mysql.GTIDSet) error { return nil } +func (h *DummyEventHandler) OnPosSynced(mysql.Position, bool) error { return nil } +func (h *DummyEventHandler) String() string { return "DummyEventHandler" } + +// `SetEventHandler` registers the sync handler, you must register your +// own handler before starting Canal. +func (c *Canal) SetEventHandler(h EventHandler) { + c.eventHandler = h +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/master.go b/vendor/github.com/siddontang/go-mysql/canal/master.go index 6897d05..10a230b 100644 --- a/vendor/github.com/siddontang/go-mysql/canal/master.go +++ b/vendor/github.com/siddontang/go-mysql/canal/master.go @@ -1,89 +1,66 @@ package canal import ( - "bytes" - "os" "sync" - "time" - "github.com/BurntSushi/toml" - "github.com/juju/errors" - "github.com/ngaut/log" + "github.com/siddontang/go-log/log" "github.com/siddontang/go-mysql/mysql" - "github.com/siddontang/go/ioutil2" ) type masterInfo struct { - Addr string `toml:"addr"` - Name string `toml:"bin_name"` - Position uint32 `toml:"bin_pos"` + sync.RWMutex - name string + pos mysql.Position - l sync.Mutex + gset mysql.GTIDSet - lastSaveTime time.Time + timestamp uint32 } -func loadMasterInfo(name string) (*masterInfo, error) { - var m masterInfo +func (m *masterInfo) Update(pos mysql.Position) { + log.Debugf("update master position %s", pos) - m.name = name - - f, err := os.Open(name) - if err != nil && !os.IsNotExist(errors.Cause(err)) { - return nil, errors.Trace(err) - } else if os.IsNotExist(errors.Cause(err)) { - return &m, nil - } - defer f.Close() - - _, err = toml.DecodeReader(f, &m) - - return &m, err + m.Lock() + m.pos = pos + m.Unlock() } -func (m *masterInfo) Save(force bool) error { - m.l.Lock() - defer m.l.Unlock() +func (m *masterInfo) UpdateTimestamp(ts uint32) { + log.Debugf("update master timestamp %s", ts) - n := time.Now() - if !force && n.Sub(m.lastSaveTime) < time.Second { + m.Lock() + m.timestamp = ts + m.Unlock() +} + +func (m *masterInfo) UpdateGTIDSet(gset mysql.GTIDSet) { + log.Debugf("update master gtid set %s", gset) + + m.Lock() + m.gset = gset + m.Unlock() +} + +func (m *masterInfo) Position() mysql.Position { + m.RLock() + defer m.RUnlock() + + return m.pos +} + +func (m *masterInfo) Timestamp() uint32 { + m.RLock() + defer m.RUnlock() + + return m.timestamp +} + +func (m *masterInfo) GTIDSet() mysql.GTIDSet { + m.RLock() + defer m.RUnlock() + + if m.gset == nil { return nil } - - var buf bytes.Buffer - e := toml.NewEncoder(&buf) - - e.Encode(m) - - var err error - if err = ioutil2.WriteFileAtomic(m.name, buf.Bytes(), 0644); err != nil { - log.Errorf("canal save master info to file %s err %v", m.name, err) - } - - m.lastSaveTime = n - - return errors.Trace(err) -} - -func (m *masterInfo) Update(name string, pos uint32) { - m.l.Lock() - m.Name = name - m.Position = pos - m.l.Unlock() -} - -func (m *masterInfo) Pos() mysql.Position { - var pos mysql.Position - m.l.Lock() - pos.Name = m.Name - pos.Pos = m.Position - m.l.Unlock() - - return pos -} - -func (m *masterInfo) Close() { - m.Save(true) + return m.gset.Clone() } diff --git a/vendor/github.com/siddontang/go-mysql/canal/rows.go b/vendor/github.com/siddontang/go-mysql/canal/rows.go index 5c5e467..e246ee5 100644 --- a/vendor/github.com/siddontang/go-mysql/canal/rows.go +++ b/vendor/github.com/siddontang/go-mysql/canal/rows.go @@ -3,16 +3,18 @@ package canal import ( "fmt" - "github.com/juju/errors" + "github.com/siddontang/go-mysql/replication" "github.com/siddontang/go-mysql/schema" ) +// The action name for sync. const ( UpdateAction = "update" InsertAction = "insert" DeleteAction = "delete" ) +// RowsEvent is the event for row replication. type RowsEvent struct { Table *schema.Table Action string @@ -22,35 +24,49 @@ type RowsEvent struct { // Two rows for one event, format is [before update row, after update row] // for update v0, only one row for a event, and we don't support this version. Rows [][]interface{} + // Header can be used to inspect the event + Header *replication.EventHeader } -func newRowsEvent(table *schema.Table, action string, rows [][]interface{}) *RowsEvent { +func newRowsEvent(table *schema.Table, action string, rows [][]interface{}, header *replication.EventHeader) *RowsEvent { e := new(RowsEvent) e.Table = table e.Action = action e.Rows = rows + e.Header = header + + e.handleUnsigned() return e } -// Get primary keys in one row for a table, a table may use multi fields as the PK -func GetPKValues(table *schema.Table, row []interface{}) ([]interface{}, error) { - indexes := table.PKColumns - if len(indexes) == 0 { - return nil, errors.Errorf("table %s has no PK", table) - } else if len(table.Columns) != len(row) { - return nil, errors.Errorf("table %s has %d columns, but row data %v len is %d", table, - len(table.Columns), row, len(row)) +func (r *RowsEvent) handleUnsigned() { + // Handle Unsigned Columns here, for binlog replication, we can't know the integer is unsigned or not, + // so we use int type but this may cause overflow outside sometimes, so we must convert to the really . + // unsigned type + if len(r.Table.UnsignedColumns) == 0 { + return } - values := make([]interface{}, 0, len(indexes)) - - for _, index := range indexes { - values = append(values, row[index]) + for i := 0; i < len(r.Rows); i++ { + for _, index := range r.Table.UnsignedColumns { + switch t := r.Rows[i][index].(type) { + case int8: + r.Rows[i][index] = uint8(t) + case int16: + r.Rows[i][index] = uint16(t) + case int32: + r.Rows[i][index] = uint32(t) + case int64: + r.Rows[i][index] = uint64(t) + case int: + r.Rows[i][index] = uint(t) + default: + // nothing to do + } + } } - - return values, nil } // String implements fmt.Stringer interface. diff --git a/vendor/github.com/siddontang/go-mysql/canal/sync.go b/vendor/github.com/siddontang/go-mysql/canal/sync.go index e76eea4..4146a2d 100644 --- a/vendor/github.com/siddontang/go-mysql/canal/sync.go +++ b/vendor/github.com/siddontang/go-mysql/canal/sync.go @@ -1,49 +1,68 @@ package canal import ( + "fmt" + "regexp" "time" - "golang.org/x/net/context" - "github.com/juju/errors" - "github.com/ngaut/log" + "github.com/satori/go.uuid" + "github.com/siddontang/go-log/log" "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/replication" + "github.com/siddontang/go-mysql/schema" ) -func (c *Canal) startSyncBinlog() error { - pos := mysql.Position{c.master.Name, c.master.Position} +var ( + expCreateTable = regexp.MustCompile("(?i)^CREATE\\sTABLE(\\sIF\\sNOT\\sEXISTS)?\\s`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}\\s.*") + expAlterTable = regexp.MustCompile("(?i)^ALTER\\sTABLE\\s.*?`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}\\s.*") + expRenameTable = regexp.MustCompile("(?i)^RENAME\\sTABLE\\s.*?`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}\\s{1,}TO\\s.*?") + expDropTable = regexp.MustCompile("(?i)^DROP\\sTABLE(\\sIF\\sEXISTS){0,1}\\s`{0,1}(.*?)`{0,1}\\.{0,1}`{0,1}([^`\\.]+?)`{0,1}(?:$|\\s)") + expTruncateTable = regexp.MustCompile("(?i)^TRUNCATE\\s+(?:TABLE\\s+)?(?:`?([^`\\s]+)`?\\.`?)?([^`\\s]+)`?") +) - log.Infof("start sync binlog at %v", pos) +func (c *Canal) startSyncer() (*replication.BinlogStreamer, error) { + gset := c.master.GTIDSet() + if gset == nil { + pos := c.master.Position() + s, err := c.syncer.StartSync(pos) + if err != nil { + return nil, errors.Errorf("start sync replication at binlog %v error %v", pos, err) + } + log.Infof("start sync binlog at binlog file %v", pos) + return s, nil + } else { + s, err := c.syncer.StartSyncGTID(gset) + if err != nil { + return nil, errors.Errorf("start sync replication at GTID set %v error %v", gset, err) + } + log.Infof("start sync binlog at GTID set %v", gset) + return s, nil + } +} - s, err := c.syncer.StartSync(pos) +func (c *Canal) runSyncBinlog() error { + s, err := c.startSyncer() if err != nil { - return errors.Errorf("start sync replication at %v error %v", pos, err) + return err } - timeout := time.Second - forceSavePos := false + savePos := false + force := false for { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - ev, err := s.GetEvent(ctx) - cancel() - - if err == context.DeadlineExceeded { - timeout = 2 * timeout - continue - } + ev, err := s.GetEvent(c.ctx) if err != nil { return errors.Trace(err) } + savePos = false + force = false + pos := c.master.Position() - timeout = time.Second - + curPos := pos.Pos //next binlog pos pos.Pos = ev.Header.LogPos - forceSavePos = false - // We only save position with RotateEvent and XIDEvent. // For RowsEvent, we can't save the position until meeting XIDEvent // which tells the whole transaction is over. @@ -52,24 +71,105 @@ func (c *Canal) startSyncBinlog() error { case *replication.RotateEvent: pos.Name = string(e.NextLogName) pos.Pos = uint32(e.Position) - // r.ev <- pos - forceSavePos = true - log.Infof("rotate binlog to %v", pos) + log.Infof("rotate binlog to %s", pos) + savePos = true + force = true + if err = c.eventHandler.OnRotate(e); err != nil { + return errors.Trace(err) + } case *replication.RowsEvent: // we only focus row based event - if err = c.handleRowsEvent(ev); err != nil { - log.Errorf("handle rows event error %v", err) - return errors.Trace(err) + err = c.handleRowsEvent(ev) + if err != nil { + e := errors.Cause(err) + // if error is not ErrExcludedTable or ErrTableNotExist or ErrMissingTableMeta, stop canal + if e != ErrExcludedTable && + e != schema.ErrTableNotExist && + e != schema.ErrMissingTableMeta { + log.Errorf("handle rows event at (%s, %d) error %v", pos.Name, curPos, err) + return errors.Trace(err) + } } continue case *replication.XIDEvent: + if e.GSet != nil { + c.master.UpdateGTIDSet(e.GSet) + } + savePos = true // try to save the position later + if err := c.eventHandler.OnXID(pos); err != nil { + return errors.Trace(err) + } + case *replication.MariadbGTIDEvent: + // try to save the GTID later + gtid, err := mysql.ParseMariadbGTIDSet(e.GTID.String()) + if err != nil { + return errors.Trace(err) + } + if err := c.eventHandler.OnGTID(gtid); err != nil { + return errors.Trace(err) + } + case *replication.GTIDEvent: + u, _ := uuid.FromBytes(e.SID) + gtid, err := mysql.ParseMysqlGTIDSet(fmt.Sprintf("%s:%d", u.String(), e.GNO)) + if err != nil { + return errors.Trace(err) + } + if err := c.eventHandler.OnGTID(gtid); err != nil { + return errors.Trace(err) + } + case *replication.QueryEvent: + if e.GSet != nil { + c.master.UpdateGTIDSet(e.GSet) + } + var ( + mb [][]byte + db []byte + table []byte + ) + regexps := []regexp.Regexp{*expCreateTable, *expAlterTable, *expRenameTable, *expDropTable, *expTruncateTable} + for _, reg := range regexps { + mb = reg.FindSubmatch(e.Query) + if len(mb) != 0 { + break + } + } + mbLen := len(mb) + if mbLen == 0 { + continue + } + + // the first last is table name, the second last is database name(if exists) + if len(mb[mbLen-2]) == 0 { + db = e.Schema + } else { + db = mb[mbLen-2] + } + table = mb[mbLen-1] + + savePos = true + force = true + c.ClearTableCache(db, table) + log.Infof("table structure changed, clear table cache: %s.%s\n", db, table) + if err = c.eventHandler.OnTableChanged(string(db), string(table)); err != nil && errors.Cause(err) != schema.ErrTableNotExist { + return errors.Trace(err) + } + + // Now we only handle Table Changed DDL, maybe we will support more later. + if err = c.eventHandler.OnDDL(pos, e); err != nil { + return errors.Trace(err) + } default: continue } - c.master.Update(pos.Name, pos.Pos) - c.master.Save(forceSavePos) + if savePos { + c.master.Update(pos) + c.master.UpdateTimestamp(ev.Header.Timestamp) + if err := c.eventHandler.OnPosSynced(pos, force); err != nil { + return errors.Trace(err) + } + } } return nil @@ -84,7 +184,7 @@ func (c *Canal) handleRowsEvent(e *replication.BinlogEvent) error { t, err := c.GetTable(schema, table) if err != nil { - return errors.Trace(err) + return err } var action string switch e.Header.EventType { @@ -97,25 +197,31 @@ func (c *Canal) handleRowsEvent(e *replication.BinlogEvent) error { default: return errors.Errorf("%s not supported now", e.Header.EventType) } - events := newRowsEvent(t, action, ev.Rows) - return c.travelRowsEventHandler(events) + events := newRowsEvent(t, action, ev.Rows, e.Header) + return c.eventHandler.OnRow(events) } -func (c *Canal) WaitUntilPos(pos mysql.Position, timeout int) error { - if timeout <= 0 { - timeout = 60 - } +func (c *Canal) FlushBinlog() error { + _, err := c.Execute("FLUSH BINARY LOGS") + return errors.Trace(err) +} - timer := time.NewTimer(time.Duration(timeout) * time.Second) +func (c *Canal) WaitUntilPos(pos mysql.Position, timeout time.Duration) error { + timer := time.NewTimer(timeout) for { select { case <-timer.C: - return errors.Errorf("wait position %v err", pos) + return errors.Errorf("wait position %v too long > %s", pos, timeout) default: - curpos := c.master.Pos() - if curpos.Compare(pos) >= 0 { + err := c.FlushBinlog() + if err != nil { + return errors.Trace(err) + } + curPos := c.master.Position() + if curPos.Compare(pos) >= 0 { return nil } else { + log.Debugf("master pos is %v, wait catching %v", curPos, pos) time.Sleep(100 * time.Millisecond) } } @@ -124,14 +230,46 @@ func (c *Canal) WaitUntilPos(pos mysql.Position, timeout int) error { return nil } -func (c *Canal) CatchMasterPos(timeout int) error { +func (c *Canal) GetMasterPos() (mysql.Position, error) { rr, err := c.Execute("SHOW MASTER STATUS") if err != nil { - return errors.Trace(err) + return mysql.Position{}, errors.Trace(err) } name, _ := rr.GetString(0, 0) pos, _ := rr.GetInt(0, 1) - return c.WaitUntilPos(mysql.Position{name, uint32(pos)}, timeout) + return mysql.Position{Name: name, Pos: uint32(pos)}, nil +} + +func (c *Canal) GetMasterGTIDSet() (mysql.GTIDSet, error) { + query := "" + switch c.cfg.Flavor { + case mysql.MariaDBFlavor: + query = "SELECT @@GLOBAL.gtid_current_pos" + default: + query = "SELECT @@GLOBAL.GTID_EXECUTED" + } + rr, err := c.Execute(query) + if err != nil { + return nil, errors.Trace(err) + } + gx, err := rr.GetString(0, 0) + if err != nil { + return nil, errors.Trace(err) + } + gset, err := mysql.ParseGTIDSet(c.cfg.Flavor, gx) + if err != nil { + return nil, errors.Trace(err) + } + return gset, nil +} + +func (c *Canal) CatchMasterPos(timeout time.Duration) error { + pos, err := c.GetMasterPos() + if err != nil { + return errors.Trace(err) + } + + return c.WaitUntilPos(pos, timeout) } diff --git a/vendor/github.com/siddontang/go-mysql/clear_vendor.sh b/vendor/github.com/siddontang/go-mysql/clear_vendor.sh new file mode 100755 index 0000000..81ba6b1 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/clear_vendor.sh @@ -0,0 +1,6 @@ +find vendor \( -type f -or -type l \) -not -name "*.go" -not -name "LICENSE" -not -name "*.s" -not -name "PATENTS" -not -name "*.h" -not -name "*.c" | xargs -I {} rm {} +# delete all test files +find vendor -type f -name "*_generated.go" | xargs -I {} rm {} +find vendor -type f -name "*_test.go" | xargs -I {} rm {} +find vendor -type d -name "_vendor" | xargs -I {} rm -rf {} +find vendor -type d -empty | xargs -I {} rm -rf {} \ No newline at end of file diff --git a/vendor/github.com/siddontang/go-mysql/client/auth.go b/vendor/github.com/siddontang/go-mysql/client/auth.go index 85b688c..5ba9c9f 100644 --- a/vendor/github.com/siddontang/go-mysql/client/auth.go +++ b/vendor/github.com/siddontang/go-mysql/client/auth.go @@ -4,12 +4,29 @@ import ( "bytes" "crypto/tls" "encoding/binary" + "fmt" "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/packet" ) +const defaultAuthPluginName = AUTH_NATIVE_PASSWORD + +// defines the supported auth plugins +var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD} + +// helper function to determine what auth methods are allowed by this client +func authPluginAllowed(pluginName string) bool { + for _, p := range supportedAuthPlugins { + if pluginName == p { + return true + } + } + return false +} + +// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake func (c *Conn) readInitialHandshake() error { data, err := c.ReadPacket() if err != nil { @@ -24,39 +41,44 @@ func (c *Conn) readInitialHandshake() error { return errors.Errorf("invalid protocol version %d, must >= 10", data[0]) } - //skip mysql version - //mysql version end with 0x00 + // skip mysql version + // mysql version end with 0x00 pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 - //connection id length is 4 + // connection id length is 4 c.connectionID = uint32(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 c.salt = []byte{} c.salt = append(c.salt, data[pos:pos+8]...) - //skip filter + // skip filter pos += 8 + 1 - //capability lower 2 bytes + // capability lower 2 bytes c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2])) - + // check protocol + if c.capability&CLIENT_PROTOCOL_41 == 0 { + return errors.New("the MySQL server can not support protocol 41 and above required by the client") + } + if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil { + return errors.New("the MySQL Server does not support TLS required by the client") + } pos += 2 if len(data) > pos { - //skip server charset + // skip server charset //c.charset = data[pos] pos += 1 c.status = binary.LittleEndian.Uint16(data[pos : pos+2]) pos += 2 - + // capability flags (upper 2 bytes) c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability - pos += 2 - //skip auth data len or [00] - //skip reserved (all [00]) + // skip auth data len or [00] + // skip reserved (all [00]) pos += 10 + 1 // The documentation is ambiguous about the length. @@ -64,78 +86,131 @@ func (c *Conn) readInitialHandshake() error { // mysql-proxy also use 12 // which is not documented but seems to work. c.salt = append(c.salt, data[pos:pos+12]...) + pos += 13 + // auth plugin + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + c.authPluginName = string(data[pos : pos+end]) + } else { + c.authPluginName = string(data[pos:]) + } + } + + // if server gives no default auth plugin name, use a client default + if c.authPluginName == "" { + c.authPluginName = defaultAuthPluginName } return nil } +// generate auth response data according to auth plugin +// +// NOTE: the returned boolean value indicates whether to add a \NUL to the end of data. +// it is quite tricky because MySQl server expects different formats of responses in different auth situations. +// here the \NUL needs to be added when sending back the empty password or cleartext password in 'sha256_password' +// authentication. +func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) { + // password hashing + switch c.authPluginName { + case AUTH_NATIVE_PASSWORD: + return CalcPassword(authData[:20], []byte(c.password)), false, nil + case AUTH_CACHING_SHA2_PASSWORD: + return CalcCachingSha2Password(authData, c.password), false, nil + case AUTH_SHA256_PASSWORD: + if len(c.password) == 0 { + return nil, true, nil + } + if c.tlsConfig != nil || c.proto == "unix" { + // write cleartext auth packet + // see: https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html + return []byte(c.password), true, nil + } else { + // request public key from server + // see: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html + return []byte{1}, false, nil + } + default: + // not reachable + return nil, false, fmt.Errorf("auth plugin '%s' is not supported", c.authPluginName) + } +} + +// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse func (c *Conn) writeAuthHandshake() error { + if !authPluginAllowed(c.authPluginName) { + return fmt.Errorf("unknow auth plugin name '%s'", c.authPluginName) + } // Adjust client capability flags based on server support capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | - CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG + CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH | c.capability&CLIENT_LONG_FLAG // To enable TLS / SSL - if c.TLSConfig != nil { - capability |= CLIENT_PLUGIN_AUTH + if c.tlsConfig != nil { capability |= CLIENT_SSL } - capability &= c.capability + auth, addNull, err := c.genAuthResponse(c.salt) + if err != nil { + return err + } + + // encode length of the auth plugin data + // here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte + // see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer + var authRespLEIBuf [9]byte + authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth))) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + } //packet length - //capbility 4 + //capability 4 //max-packet size 4 //charset 1 //reserved all[0] 23 - length := 4 + 4 + 1 + 23 - //username - length += len(c.user) + 1 - - //we only support secure connection - auth := CalcPassword(c.salt, []byte(c.password)) - - length += 1 + len(auth) - + //auth + //mysql_native_password + null-terminated + length := 4 + 4 + 1 + 23 + len(c.user) + 1 + len(authRespLEI) + len(auth) + 21 + 1 + if addNull { + length++ + } + // db name if len(c.db) > 0 { capability |= CLIENT_CONNECT_WITH_DB - length += len(c.db) + 1 } - // mysql_native_password + null-terminated - length += 21 + 1 - - c.capability = capability - data := make([]byte, length+4) - //capability [32 bit] + // capability [32 bit] data[4] = byte(capability) data[5] = byte(capability >> 8) data[6] = byte(capability >> 16) data[7] = byte(capability >> 24) - //MaxPacketSize [32 bit] (none) - //data[8] = 0x00 - //data[9] = 0x00 - //data[10] = 0x00 - //data[11] = 0x00 + // MaxPacketSize [32 bit] (none) + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 - //Charset [1 byte] - //use default collation id 33 here, is utf-8 + // Charset [1 byte] + // use default collation id 33 here, is utf-8 data[12] = byte(DEFAULT_COLLATION_ID) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest - if c.TLSConfig != nil { + if c.tlsConfig != nil { // Send TLS / SSL request packet if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil { return err } // Switch to TLS - tlsConn := tls.Client(c.Conn.Conn, c.TLSConfig) + tlsConn := tls.Client(c.Conn.Conn, c.tlsConfig) if err := tlsConn.Handshake(); err != nil { return err } @@ -145,10 +220,13 @@ func (c *Conn) writeAuthHandshake() error { c.Sequence = currentSequence } - //Filler [23 bytes] (all 0x00) - pos := 13 + 23 + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } - //User [null terminated string] + // User [null terminated string] if len(c.user) > 0 { pos += copy(data[pos:], c.user) } @@ -156,8 +234,12 @@ func (c *Conn) writeAuthHandshake() error { pos++ // auth [length encoded integer] - data[pos] = byte(len(auth)) - pos += 1 + copy(data[pos+1:], auth) + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], auth) + if addNull { + data[pos] = 0x00 + pos++ + } // db [null terminated string] if len(c.db) > 0 { @@ -167,7 +249,7 @@ func (c *Conn) writeAuthHandshake() error { } // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], c.authPluginName) data[pos] = 0x00 return c.WritePacket(data) diff --git a/vendor/github.com/siddontang/go-mysql/client/client_test.go b/vendor/github.com/siddontang/go-mysql/client/client_test.go index 85dd8f2..04bfdb2 100644 --- a/vendor/github.com/siddontang/go-mysql/client/client_test.go +++ b/vendor/github.com/siddontang/go-mysql/client/client_test.go @@ -1,41 +1,56 @@ package client import ( - "crypto/tls" "flag" "fmt" "strings" "testing" + "github.com/juju/errors" . "github.com/pingcap/check" + "github.com/siddontang/go-mysql/test_util/test_keys" "github.com/siddontang/go-mysql/mysql" ) var testHost = flag.String("host", "127.0.0.1", "MySQL server host") -var testPort = flag.Int("port", 3306, "MySQL server port") +// We cover the whole range of MySQL server versions using docker-compose to bind them to different ports for testing. +// MySQL is constantly updating auth plugin to make it secure: +// starting from MySQL 8.0.4, a new auth plugin is introduced, causing plain password auth to fail with error: +// ERROR 1251 (08004): Client does not support authentication protocol requested by server; consider upgrading MySQL client +// Hint: use docker-compose to start corresponding MySQL docker containers and add the their ports here +var testPort = flag.String("port", "3306", "MySQL server port") // choose one or more form 5561,5641,3306,5722,8003,8012,8013, e.g. '3306,5722,8003' var testUser = flag.String("user", "root", "MySQL user") var testPassword = flag.String("pass", "", "MySQL password") var testDB = flag.String("db", "test", "MySQL test database") func Test(t *testing.T) { + segs := strings.Split(*testPort, ",") + for _, seg := range segs { + Suite(&clientTestSuite{port: seg}) + } TestingT(t) } type clientTestSuite struct { - c *Conn + c *Conn + port string } -var _ = Suite(&clientTestSuite{}) - func (s *clientTestSuite) SetUpSuite(c *C) { var err error - addr := fmt.Sprintf("%s:%d", *testHost, *testPort) - s.c, err = Connect(addr, *testUser, *testPassword, *testDB) + addr := fmt.Sprintf("%s:%s", *testHost, s.port) + s.c, err = Connect(addr, *testUser, *testPassword, "") if err != nil { c.Fatal(err) } + _, err = s.c.Execute("CREATE DATABASE IF NOT EXISTS " + *testDB) + c.Assert(err, IsNil) + + _, err = s.c.Execute("USE " + *testDB) + c.Assert(err, IsNil) + s.testConn_CreateTable(c) s.testStmt_CreateTable(c) } @@ -78,12 +93,15 @@ func (s *clientTestSuite) TestConn_Ping(c *C) { c.Assert(err, IsNil) } -func (s *clientTestSuite) TestConn_TLS(c *C) { +// NOTE for MySQL 5.5 and 5.6, server side has to config SSL to pass the TLS test, otherwise, it will throw error that +// MySQL server does not support TLS required by the client. However, for MySQL 5.7 and above, auto generated certificates +// are used by default so that manual config is no longer necessary. +func (s *clientTestSuite) TestConn_TLS_Verify(c *C) { // Verify that the provided tls.Config is used when attempting to connect to mysql. // An empty tls.Config will result in a connection error. - addr := fmt.Sprintf("%s:%d", *testHost, *testPort) + addr := fmt.Sprintf("%s:%s", *testHost, s.port) _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { - c.TLSConfig = &tls.Config{} + c.UseSSL(false) }) if err == nil { c.Fatal("expected error") @@ -91,7 +109,34 @@ func (s *clientTestSuite) TestConn_TLS(c *C) { expected := "either ServerName or InsecureSkipVerify must be specified in the tls.Config" if !strings.Contains(err.Error(), expected) { - c.Fatal("expected '%s' to contain '%s'", err.Error(), expected) + c.Fatalf("expected '%s' to contain '%s'", err.Error(), expected) + } +} + +func (s *clientTestSuite) TestConn_TLS_Skip_Verify(c *C) { + // An empty tls.Config will result in a connection error but we can configure to skip it. + addr := fmt.Sprintf("%s:%s", *testHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + c.UseSSL(true) + }) + c.Assert(err, Equals, nil) +} + +func (s *clientTestSuite) TestConn_TLS_Certificate(c *C) { + // This test uses the TLS suite in 'go-mysql/docker/resources'. The certificates are not valid for any names. + // And if server uses auto-generated certificates, it will be an error like: + // "x509: certificate is valid for MySQL_Server_8.0.12_Auto_Generated_Server_Certificate, not not-a-valid-name" + tlsConfig := NewClientTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, false, "not-a-valid-name") + addr := fmt.Sprintf("%s:%s", *testHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + c.SetTLSConfig(tlsConfig) + }) + if err == nil { + c.Fatal("expected error") + } + if !strings.Contains(errors.Details(err), "certificate is not valid for any names") && + !strings.Contains(errors.Details(err), "certificate is valid for") { + c.Fatalf("expected errors for server name verification, but got unknown error: %s", errors.Details(err)) } } @@ -349,4 +394,4 @@ func (s *clientTestSuite) TestStmt_Trans(c *C) { str, _ = r.GetString(0, 0) c.Assert(str, Equals, `abc`) -} +} \ No newline at end of file diff --git a/vendor/github.com/siddontang/go-mysql/client/conn.go b/vendor/github.com/siddontang/go-mysql/client/conn.go index 54ee3f0..b015b43 100644 --- a/vendor/github.com/siddontang/go-mysql/client/conn.go +++ b/vendor/github.com/siddontang/go-mysql/client/conn.go @@ -18,7 +18,8 @@ type Conn struct { user string password string db string - TLSConfig *tls.Config + tlsConfig *tls.Config + proto string capability uint32 @@ -26,7 +27,8 @@ type Conn struct { charset string - salt []byte + salt []byte + authPluginName string connectionID uint32 } @@ -56,6 +58,7 @@ func Connect(addr string, user string, password string, dbName string, options . c.user = user c.password = password c.db = dbName + c.proto = proto //use default charset here, utf-8 c.charset = DEFAULT_CHARSET @@ -85,7 +88,7 @@ func (c *Conn) handshake() error { return errors.Trace(err) } - if _, err := c.readOK(); err != nil { + if err := c.handleAuthResult(); err != nil { c.Close() return errors.Trace(err) } @@ -109,6 +112,18 @@ func (c *Conn) Ping() error { return nil } +// use default SSL +// pass to options when connect +func (c *Conn) UseSSL(insecureSkipVerify bool) { + c.tlsConfig = &tls.Config{InsecureSkipVerify: insecureSkipVerify} +} + +// use user-specified TLS config +// pass to options when connect +func (c *Conn) SetTLSConfig(config *tls.Config) { + c.tlsConfig = config +} + func (c *Conn) UseDB(dbName string) error { if c.db == dbName { return nil diff --git a/vendor/github.com/siddontang/go-mysql/client/resp.go b/vendor/github.com/siddontang/go-mysql/client/resp.go index 4e5f855..71aa1bc 100644 --- a/vendor/github.com/siddontang/go-mysql/client/resp.go +++ b/vendor/github.com/siddontang/go-mysql/client/resp.go @@ -1,8 +1,14 @@ package client +import "C" import ( "encoding/binary" + "bytes" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go/hack" @@ -32,7 +38,7 @@ func (c *Conn) isEOFPacket(data []byte) bool { func (c *Conn) handleOKPacket(data []byte) (*Result, error) { var n int - var pos int = 1 + var pos = 1 r := new(Result) @@ -64,7 +70,7 @@ func (c *Conn) handleOKPacket(data []byte) (*Result, error) { func (c *Conn) handleErrorPacket(data []byte) error { e := new(MyError) - var pos int = 1 + var pos = 1 e.Code = binary.LittleEndian.Uint16(data[pos:]) pos += 2 @@ -81,6 +87,116 @@ func (c *Conn) handleErrorPacket(data []byte) error { return e } +func (c *Conn) handleAuthResult() error { + data, switchToPlugin, err := c.readAuthResult() + if err != nil { + return err + } + // handle auth switch, only support 'sha256_password', and 'caching_sha2_password' + if switchToPlugin != "" { + //fmt.Printf("now switching auth plugin to '%s'\n", switchToPlugin) + if data == nil { + data = c.salt + } else { + copy(c.salt, data) + } + c.authPluginName = switchToPlugin + auth, addNull, err := c.genAuthResponse(data) + if err = c.WriteAuthSwitchPacket(auth, addNull); err != nil { + return err + } + + // Read Result Packet + data, switchToPlugin, err = c.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if switchToPlugin != "" { + return errors.Errorf("can not switch auth plugin more than once") + } + } + + // handle caching_sha2_password + if c.authPluginName == AUTH_CACHING_SHA2_PASSWORD { + if data == nil { + return nil // auth already succeeded + } + if data[0] == CACHE_SHA2_FAST_AUTH { + if _, err = c.readOK(); err == nil { + return nil // auth successful + } + } else if data[0] == CACHE_SHA2_FULL_AUTH { + // need full authentication + if c.tlsConfig != nil || c.proto == "unix" { + if err = c.WriteClearAuthPacket(c.password); err != nil { + return err + } + } else { + if err = c.WritePublicKeyAuthPacket(c.password, c.salt); err != nil { + return err + } + } + } else { + errors.Errorf("invalid packet") + } + } else if c.authPluginName == AUTH_SHA256_PASSWORD { + if len(data) == 0 { + return nil // auth already succeeded + } + block, _ := pem.Decode(data) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + // send encrypted password + err = c.WriteEncryptedPassword(c.password, c.salt, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + _, err = c.readOK() + return err + } + return nil +} + +func (c *Conn) readAuthResult() ([]byte, string, error) { + data, err := c.ReadPacket() + if err != nil { + return nil, "", err + } + + // see: https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + // packet indicator + switch data[0] { + + case OK_HEADER: + _, err := c.handleOKPacket(data) + return nil, "", err + + case MORE_DATE_HEADER: + return data[1:], "", err + + case EOF_HEADER: + // server wants to switch auth + if len(data) < 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, AUTH_MYSQL_OLD_PASSWORD, nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", errors.New("invalid packet") + } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", c.handleErrorPacket(data) + } +} + func (c *Conn) readOK() (*Result, error) { data, err := c.ReadPacket() if err != nil { diff --git a/vendor/github.com/siddontang/go-mysql/client/tls.go b/vendor/github.com/siddontang/go-mysql/client/tls.go new file mode 100644 index 0000000..3772a50 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/client/tls.go @@ -0,0 +1,28 @@ +package client + +import ( + "crypto/tls" + "crypto/x509" +) + +// generate TLS config for client side +// if insecureSkipVerify is set to true, serverName will not be validated +func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool, serverName string) *tls.Config { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPem) { + panic("failed to add ca PEM") + } + + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + panic(err) + } + + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: pool, + InsecureSkipVerify: insecureSkipVerify, + ServerName: serverName, + } + return config +} diff --git a/vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go b/vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go index 3990d00..fa5b5cf 100644 --- a/vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go +++ b/vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go @@ -23,6 +23,6 @@ func main() { err := p.ParseFile(*name, *offset, f) if err != nil { - println(err) + println(err.Error()) } } diff --git a/vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go b/vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go index 4d0c1df..6f60a4a 100644 --- a/vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go +++ b/vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go @@ -7,8 +7,10 @@ import ( "os/signal" "strings" "syscall" + "time" "github.com/siddontang/go-mysql/canal" + "github.com/siddontang/go-mysql/mysql" ) var host = flag.String("host", "127.0.0.1", "MySQL host") @@ -18,8 +20,6 @@ var password = flag.String("password", "", "MySQL password") var flavor = flag.String("flavor", "mysql", "Flavor: mysql or mariadb") -var dataDir = flag.String("data-dir", "./var", "Path to store data, like master.info") - var serverID = flag.Int("server-id", 101, "Unique Server ID") var mysqldump = flag.String("mysqldump", "mysqldump", "mysqldump execution path") @@ -28,6 +28,12 @@ var tables = flag.String("tables", "", "dump tables, seperated by comma, will ov var tableDB = flag.String("table_db", "test", "database for dump tables") var ignoreTables = flag.String("ignore_tables", "", "ignore tables, must be database.table format, separated by comma") +var startName = flag.String("bin_name", "", "start sync from binlog name") +var startPos = flag.Uint("bin_pos", 0, "start sync from binlog position of") + +var heartbeatPeriod = flag.Duration("heartbeat", 60*time.Second, "master heartbeat period") +var readTimeout = flag.Duration("read_timeout", 90*time.Second, "connection read timeout") + func main() { flag.Parse() @@ -36,8 +42,10 @@ func main() { cfg.User = *user cfg.Password = *password cfg.Flavor = *flavor - cfg.DataDir = *dataDir + cfg.UseDecimal = true + cfg.ReadTimeout = *readTimeout + cfg.HeartbeatPeriod = *heartbeatPeriod cfg.ServerID = uint32(*serverID) cfg.Dump.ExecutionPath = *mysqldump cfg.Dump.DiscardErr = false @@ -65,14 +73,20 @@ func main() { c.AddDumpDatabases(subs...) } - c.RegRowsEventHandler(&handler{}) + c.SetEventHandler(&handler{}) - err = c.Start() - if err != nil { - fmt.Printf("start canal err %V", err) - os.Exit(1) + startPos := mysql.Position{ + Name: *startName, + Pos: uint32(*startPos), } + go func() { + err = c.RunFrom(startPos) + if err != nil { + fmt.Printf("start canal err %v", err) + } + }() + sc := make(chan os.Signal, 1) signal.Notify(sc, os.Kill, @@ -88,9 +102,10 @@ func main() { } type handler struct { + canal.DummyEventHandler } -func (h *handler) Do(e *canal.RowsEvent) error { +func (h *handler) OnRow(e *canal.RowsEvent) error { fmt.Printf("%v\n", e) return nil diff --git a/vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go b/vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go index 5521c97..2c19c87 100644 --- a/vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go +++ b/vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go @@ -4,12 +4,11 @@ package main import ( + "context" "flag" "fmt" "os" - "golang.org/x/net/context" - "github.com/juju/errors" "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/replication" @@ -41,13 +40,14 @@ func main() { Port: uint16(*port), User: *user, Password: *password, - RawModeEanbled: *rawMode, + RawModeEnabled: *rawMode, SemiSyncEnabled: *semiSync, + UseDecimal: true, } - b := replication.NewBinlogSyncer(&cfg) + b := replication.NewBinlogSyncer(cfg) - pos := mysql.Position{*file, uint32(*pos)} + pos := mysql.Position{Name: *file, Pos: uint32(*pos)} if len(*backupPath) > 0 { // Backup will always use RawMode. err := b.StartBackup(*backupPath, pos, 0) @@ -65,6 +65,11 @@ func main() { for { e, err := s.GetEvent(context.Background()) if err != nil { + // Try to output all left events + events := s.DumpEvents() + for _, e := range events { + e.Dump(os.Stdout) + } fmt.Printf("Get event error: %v\n", errors.ErrorStack(err)) return } diff --git a/vendor/github.com/siddontang/go-mysql/docker/docker-compose.yaml b/vendor/github.com/siddontang/go-mysql/docker/docker-compose.yaml new file mode 100644 index 0000000..151786e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/docker-compose.yaml @@ -0,0 +1,80 @@ +version: '3' +services: + + mysql-5.5.61: + image: "mysql:5.5.61" + container_name: "mysql-server-5.5.61" + ports: + - "5561:3306" + command: --ssl=TRUE --ssl-ca=/usr/local/mysql/ca.pem --ssl-cert=/usr/local/mysql/server-cert.pem --ssl-key=/usr/local/mysql/server-key.pem + volumes: + - ./resources/ca.pem:/usr/local/mysql/ca.pem + - ./resources/server-cert.pem:/usr/local/mysql/server-cert.pem + - ./resources/server-key.pem:/usr/local/mysql/server-key.pem + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-5.6.41: + image: "mysql:5.6.41" + container_name: "mysql-server-5.6.41" + ports: + - "5641:3306" + command: --ssl=TRUE --ssl-ca=/usr/local/mysql/ca.pem --ssl-cert=/usr/local/mysql/server-cert.pem --ssl-key=/usr/local/mysql/server-key.pem + volumes: + - ./resources/ca.pem:/usr/local/mysql/ca.pem + - ./resources/server-cert.pem:/usr/local/mysql/server-cert.pem + - ./resources/server-key.pem:/usr/local/mysql/server-key.pem + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-default: + image: "mysql:5.7.22" + container_name: "mysql-server-default" + ports: + - "3306:3306" + command: ["mysqld", "--log-bin=mysql-bin", "--server-id=1"] + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-5.7.22: + image: "mysql:5.7.22" + container_name: "mysql-server-5.7.22" + ports: + - "5722:3306" + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-8.0.3: + image: "mysql:8.0.3" + container_name: "mysql-server-8.0.3" + ports: + - "8003:3306" + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-8.0.12: + image: "mysql:8.0.12" + container_name: "mysql-server-8.0.12" + ports: + - "8012:3306" + environment: + #- MYSQL_ROOT_PASSWORD=abc123 + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-8.0.12-sha256: + image: "mysql:8.0.12" + container_name: "mysql-server-8.0.12-sha256" + ports: + - "8013:3306" + entrypoint: ['/entrypoint.sh', '--default-authentication-plugin=sha256_password'] + environment: + #- MYSQL_ROOT_PASSWORD=abc123 + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + diff --git a/vendor/github.com/siddontang/go-mysql/docker/resources/ca.key b/vendor/github.com/siddontang/go-mysql/docker/resources/ca.key new file mode 100644 index 0000000..8344ed2 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/resources/ca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QL +suT1fGf5zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eo +xrS64GSbPVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKv +lpEAtRAe4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7 +IeXUc456yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioO +NaePQKFR/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABAoIBAC2U0jponRiGmgIl +gohw6+D+6pNeaKAAUkwYbKXJZ3noWLFr4T3GDTg9WDqvcvJg+rT9NvZxdCW3tDc5 +CVBcwO1g9PVcUEaRqcme3EhrxKdQQ76QmjUGeQf1ktd+YnmiZ1kOnGLtZ9/gsYpQ +06iGSIOX3+xA4BQOhEAPCOShMjYv+pWvWGhZCSmeulKulNVPBbG2H1I9EoT5Wd+Q +8LUfgZOuUXrtcsuvEf2XeacCo0pUbjx8ErhDHP6aPasFAXq15Bm8DnsUOrrsjcLy +sPy/mHwpd6kTw+O3EzjTdaYSFRoDSpfpIS5Bk+yicdxOmTwp1pzDu6HyYnuOnc9Q +JQ8HvlECgYEA2z+1HKVz5k7NYyRihW4l30vAcAGcgG1RObB6DmLbGu4MPvMymgLO +1QhYjlCcKfRHhVS2864op3Oba2fIgCc2am0DIQQ6kZ23ick78aj9G2ZXYpdpIPLu +Kl1AZHj6XDrOPVqidwcE6iYHLLWp9x4Atgw5d44XmhQ0kwrqAfccOX8CgYEAzxnl +7Uu+v5WI3hBVJpxS6eoS1TdztVIJaumyE43pBoHEuJrp4MRf0Lu2DiDpH8R3/RoE +o+ykn6xzphYwUopYaCWzYTKoXvxCvmLkDjHcpdzLtwWbKG+MJih2nTADEDI7sK4e +a3IU8miK6FeqkQHfs/5dlQa8q31yxiukw0qQEP0CgYAtLg6jTZD5l6mJUZkfx9f0 +EMciDaLzcBN54Nz2E/b0sLNDUZhO1l9K1QJyqTfVCWqnlhJxWqU0BIW1d1iA2BPF +kJtBdX6gPTDyKs64eMtXlxpQzcSzLnxXrIm1apyk3tVbHU83WfHwUk/OLc1NiBg7 +a394HIbOkHVZC7m3F/Xv/wKBgQDHrM2du8D+kJs0l4SxxFjAxPlBb8R01tLTrNwP +tGwu5OEZp+rE1jEXXFRMTPjXsyKI+hPtRJT4ilm6kXwnqNFSIL9RgHkLk6Z6T3hY +I0T8+ePD43jURLBYffzW0tqxO+2HDGmx6H0/twHuv89pHehkb2Qk8ijoIvyNCrlB +vVsntQKBgCK04nbb+G45D6TKCcZ6XKT/+qneJQE5cfvHl5EqrfjSmlnEUpJjJfyc +6Q1PtXtWOtOScU93f1JKL7+JBbWDn9uBlboM8BSkAVVd/2vyg88RuEtIru1syxcW +d1rMxqaMRJuhuqaS33CoPUpn30b4zVrPhQJ2+TwDAol4qIGHaie8 +-----END RSA PRIVATE KEY----- diff --git a/vendor/github.com/siddontang/go-mysql/docker/resources/ca.pem b/vendor/github.com/siddontang/go-mysql/docker/resources/ca.pem new file mode 100644 index 0000000..e251bd6 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/resources/ca.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIJANeS1FOzWXlZMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTgwODE2MTUxNDE5WhcNMjEwNjA1MTUxNDE5WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QLsuT1fGf5 +zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eoxrS64GSb +PVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKvlpEAtRAe +4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7IeXUc456 +yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioONaePQKFR +/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABo4GnMIGkMB0GA1UdDgQWBBQgHiwD +00upIbCOunlK4HRw89DhjjB1BgNVHSMEbjBsgBQgHiwD00upIbCOunlK4HRw89Dh +jqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV +BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJANeS1FOzWXlZMAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAFMZFQTFKU5tWIpWh8BbVZeVZcng0Kiq +qwbhVwaTkqtfmbqw8/w+faOWylmLncQEMmgvnUltGMQlQKBwQM2byzPkz9phal3g +uI0JWJYqtcMyIQUB9QbbhrDNC9kdt/ji/x6rrIqzaMRuiBXqH5LQ9h856yXzArqd +cAQGzzYpbUCIv7ciSB93cKkU73fQLZVy5ZBy1+oAa1V9U4cb4G/20/PDmT+G3Gxz +pEjeDKtz8XINoWgA2cSdfAhNZt5vqJaCIZ8qN0z6C7SUKwUBderERUMLUXdhUldC +KTVHyEPvd0aULd5S5vEpKCnHcQmFcLdoN8t9k9pR9ZgwqXbyJHlxWFo= +-----END CERTIFICATE----- diff --git a/vendor/github.com/siddontang/go-mysql/docker/resources/client-cert.pem b/vendor/github.com/siddontang/go-mysql/docker/resources/client-cert.pem new file mode 100644 index 0000000..e478e78 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/resources/client-cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe4CCQDg06wCf7hcuTANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDY0N1oXDTI4MDgxNjA4NDY0N1owRTELMAkG +A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0 +IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AMmivNyk3Rc1ZvLPhb3WPNkf9f2G4g9nMc0+eMrR1IKJ1U1A98ojeIBT+pfk1bSq +Ol0UDm66Vd3YQ+4HpyYHaYV6mwoTEulL9Quk8RLa7TRwQu3PLi3o567RhVIrx8Z3 +umuWb9UUzJfSFH04Uy9+By4CJCqIQXU4BocLIKHhIkNjmAQ9fWO1hZ8zmPHSEfvu +Wqa/DYKGvF0MJr4Lnkm/sKUd+O94p9suvwM6OGIDibACiKRF2H+JbgQLbA58zkLv +DHtXOqsCL7HxiONX8VDrQjN/66Nh9omk/Bx2Ec8IqappHvWf768HSH79x/znaial +VEV+6K0gP+voJHfnA10laWMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAPD+Fn1qj +HN62GD3eIgx6wJxYuemhdbgmEwrZZf4V70lS6e9Iloif0nBiISDxJUpXVWNRCN3Z +3QVC++F7deDmWL/3dSpXRvWsapzbCUhVQ2iBcnZ7QCOdvAqYR1ecZx70zvXCwBcd +6XKmRtdeNV6B211KRFmTYtVyPq4rcWrkTPGwPBncJI1eQQmyFv2T9SwVVp96Nbrq +sf7zrJGmuVCdXGPRi/ALVHtJCz6oPoft3I707eMe+ijnFqwGbmMD4fMD6Ync/hEz +PyR5FMZkXSXHS0gkA5pfwW7wJ2WSWDhI6JMS1gbatY7QzgHbKoQpxBPUXlnzzj2h +7O9cgFTh/XOZXQ== +-----END CERTIFICATE----- diff --git a/vendor/github.com/siddontang/go-mysql/docker/resources/client-key.pem b/vendor/github.com/siddontang/go-mysql/docker/resources/client-key.pem new file mode 100644 index 0000000..996a97b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/resources/client-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAyaK83KTdFzVm8s+FvdY82R/1/YbiD2cxzT54ytHUgonVTUD3 +yiN4gFP6l+TVtKo6XRQObrpV3dhD7genJgdphXqbChMS6Uv1C6TxEtrtNHBC7c8u +LejnrtGFUivHxne6a5Zv1RTMl9IUfThTL34HLgIkKohBdTgGhwsgoeEiQ2OYBD19 +Y7WFnzOY8dIR++5apr8Ngoa8XQwmvgueSb+wpR3473in2y6/Azo4YgOJsAKIpEXY +f4luBAtsDnzOQu8Me1c6qwIvsfGI41fxUOtCM3/ro2H2iaT8HHYRzwipqmke9Z/v +rwdIfv3H/OdqJqVURX7orSA/6+gkd+cDXSVpYwIDAQABAoIBAAGLY5L1GFRzLkSx +3j5kA7dODV5RyC2CBtmhnt8+2DffwmiDFOLRfrzM5+B9+j0WCLhpzOqANuQqIesS +1+7so5xIIiPjnYN393qNWuNgFe0O5xRXP+1OGWg3ZqQIfdFBXYYxcs3ZCPAoxctn +wQteFcP+dDR3MrkpIrOqHCfhR5foieOMP+9k5kCjk+aZqhEmFyko+X+xVO/32xs+ ++3qXhUrHt3Op5on30QMOFguniQlYwLJkd9qVjGuGMIrVPxoUz0rya4SKrGKgkAr8 +mvQe2+sZo7cc6zC2ceaGMJU7z1RalTrCObbg5mynlu+Vf0E/YiES0abkQhSbcSB9 +mAkJC7ECgYEA/H1NDEiO164yYK9ji4HM/8CmHegWS4qsgrzAs8lU0yAcgdg9e19A +mNi8yssfIBCw62RRE4UGWS5F82myhmvq/mXbf8eCJ2CMgdCHQh1rT7WFD/Uc5Pe/ +8Lv2jNMQ61POguPyq6D0qcf8iigKIMHa1MIgAOmrgWrxulfbSUhm370CgYEAzHBu +J9p4dAqW32+Hrtv2XE0KUjH72TXr13WErosgeGTfsIW2exXByvLasxOJSY4Wb8oS +OLZ7bgp/EBchAc7my+nF8n5uOJxipWQUB5BoeB9aUJZ9AnWF4RDl94Jlm5PYBG/J +lRXrMtSTTIgmSw3Ft2A1vRMOQaHX89lNwOZL758CgYAXOT84/gOFexRPKFKzpkDA +1WtyHMLQN/UeIVZoMwCGWtHEb6tYCa7bYDQdQwmd3Wsoe5WpgfbPhR4SAYrWKl72 +/09tNWCXVp4V4qRORH52Wm/ew+Dgfpk8/0zyLwfDXXYFPAo6Fxfp9ecYng4wbSQ/ +pYtkChooUTniteoJl4s+0QKBgHbFEpoAqF3yEPi52L/TdmrlLwvVkhT86IkB8xVc +Kn8HS5VH+V3EpBN9x2SmAupCq/JCGRftnAOwAWWdqkVcqGTq6V8Z6HrnD8A6RhCm +6qpuvI94/iNBl4fLw25pyRH7cFITh68fTsb3DKQ3rNeJpsYEFPRFb9Ddb5JxOmTI +5nDNAoGBAM+SyOhUGU+0Uw2WJaGWzmEutjeMRr5Z+cZ8keC/ZJNdji/faaQoeOQR +OXI8O6RBTBwVNQMyDyttT8J8BkISwfAhSdPkjgPw9GZ1pGREl53uCFDIlX2nvtQM +ioNzG5WHB7Gd7eUUTA91kRF9MZJTHPqNiNGR0Udj/trGyGqJebni +-----END RSA PRIVATE KEY----- diff --git a/vendor/github.com/siddontang/go-mysql/docker/resources/server-cert.pem b/vendor/github.com/siddontang/go-mysql/docker/resources/server-cert.pem new file mode 100644 index 0000000..3cb3b9c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/resources/server-cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe4CCQDg06wCf7hcuDANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDUyNVoXDTI4MDgxNjA4NDUyNVowRTELMAkG +A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0 +IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALK2gqK4uvTlxJANO2JKdibvmh899z6oCo9Km0mz5unj4dpnq9hljsQuKtcHUcM4 +HXcE06knaJ4TOF7lcsjaqoDO7r/SaFgjjXCqNvHD0Su4B+7qe52BZZTRV1AANP10 +PvebarXSEzaZUCyHHhSF8+Qb4vX04XKX/TOqinTVGtlnduKzP5+qsaFBtpLAw1V0 +At9EQB5BgnTYtdIsmvD4/2WhBvOjVxab75yx0R4oof4F3u528tbEegcWhBtmy2Xd +HI3S+TLljj3kOOdB+pgrVUl+KaDavWK3T+F1vTNDe56HEVNKeWlLy1scul61E0j9 +IkZAu6aRDxtKdl7bKu0BkzMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAma3yFqR7 +xkeaZBg4/1I3jSlaNe5+2JB4iybAkMOu77fG5zytLomTbzdhewsuBwpTVMJdga8T +IdPeIFCin1U+5SkbjSMlpKf+krE+5CyrNJ5jAgO9ATIqx66oCTYXfGlNapGRLfSE +sa0iMqCe/dr4GPU+flW2DZFWiyJVDSF1JjReQnfrWY+SD2SpP/lmlgltnY8MJngd +xBLG5nsZCpUXGB713Q8ZyIm2ThVAMiskcxBleIZDDghLuhGvY/9eFJhZpvOkjWa6 +XGEi4E1G/SA+zVKFl41nHKCdqXdmIOnpcLlFBUVloQok5a95Kqc1TYw3f+WbdFff +99dAgk3gWwWZQA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/vendor/github.com/siddontang/go-mysql/docker/resources/server-key.pem b/vendor/github.com/siddontang/go-mysql/docker/resources/server-key.pem new file mode 100644 index 0000000..babaaae --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/resources/server-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAsraCori69OXEkA07Ykp2Ju+aHz33PqgKj0qbSbPm6ePh2mer +2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqqgM7uv9JoWCONcKo28cPRK7gH7up7 +nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz5Bvi9fThcpf9M6qKdNUa2Wd24rM/ +n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ZaEG86NXFpvvnLHRHiih/gXe7nby +1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtVSX4poNq9YrdP4XW9M0N7nocRU0p5 +aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGTMwIDAQABAoIBAGh1m8hHWCg7gXh9 +838RbRx3IswuKS27hWiaQEiFWmzOIb7KqDy1qAxtu+ayRY1paHegH6QY/+Kd824s +ibpzbgQacJ04/HrAVTVMmQ8Z2VLHoAN7lcPL1bd14aZGaLLZVtDeTDJ413grhxxv +4ho27gcgcbo4Z+rWgk7H2WRPCAGYqWYAycm3yF5vy9QaO6edU+T588YsEQOos5iy +5pVFSGDGZkcUp1ukL3BJYR+jvygn6WPCobQ/LScUdi+ucitaI9i+UdlLokZARVRG +M/msqcTM73thR8yVRcexU6NUDxRBfZ/f7moSAEbBmGDXuxDcIyH9KGMQ2rMtN1X3 +lK8UNwkCgYEA2STJq/IUQHjdqd3Dqh/Q7Zm8/pMWFqLJSkqpnFtFsXPyUOx9zDOy +KqkIfGeyKwvsj9X9BcZ0FUKj9zoct1/WpPY+h7i7+z0MIujBh4AMjAcDrt4o76yK +UHuVmG2xKTdJoAbqOdToQeX6E82Ioal5pbB2W7AbCQScNBPZ52jxgtcCgYEA0rE7 +2dFiRm0YmuszFYxft2+GP6NgP3R2TQNEooi1uCXG2xgwObie1YCHzpZ5CfSqJIxP +XB7DXpIWi7PxJoeai2F83LnmdFz6F1BPRobwDoSFNdaSKLg4Yf856zpgYNKhL1fE +OoOXj4VBWBZh1XDfZV44fgwlMIf7edOF1XOagwUCgYAw953O+7FbdKYwF0V3iOM5 +oZDAK/UwN5eC/GFRVDfcM5RycVJRCVtlSWcTfuLr2C2Jpiz/72fgH34QU3eEVsV1 +v94MBznFB1hESw7ReqvZq/9FoO3EVrl+OtBaZmosLD6bKtQJJJ0Xtz/01UW5hxla +pveZ55XBK9v51nwuNjk4UwKBgHD8fJUllSchUCWb5cwzeAz98Kdl7LJ6uQo5q2/i +EllLYOWThiEeIYdrIuklholRPIDXAaPsF2c6vn5yo+q+o6EFSZlw0+YpCjDAb5Lp +wAh5BprFk6HkkM/0t9Guf4rMyYWC8odSlE9x7YXYkuSMYDCTI4Zs6vCoq7I8PbQn +B4AlAoGAZ6Ee5m/ph5UVp/3+cR6jCY7aHBUU/M3pbJSkVjBW+ymEBVJ6sUdz8k3P +x8BiPEQggNN7faWBqRWP7KXPnDYHh6shYUgPJwI5HX6NE/ZDnnXjeysHRyf0oCo5 +S6tHXwHNKB5HS1c/KDyyNGjP2oi/MF4o/MGWNWEcK6TJA3RGOYM= +-----END RSA PRIVATE KEY----- diff --git a/vendor/github.com/siddontang/go-mysql/driver/dirver_test.go b/vendor/github.com/siddontang/go-mysql/driver/dirver_test.go index 54a72cb..d43580f 100644 --- a/vendor/github.com/siddontang/go-mysql/driver/dirver_test.go +++ b/vendor/github.com/siddontang/go-mysql/driver/dirver_test.go @@ -11,6 +11,11 @@ import ( // Use docker mysql to test, mysql is 3306 var testHost = flag.String("host", "127.0.0.1", "MySQL master host") +// possible choices for different MySQL versions are: 5561,5641,3306,5722,8003,8012 +var testPort = flag.Int("port", 3306, "MySQL server port") +var testUser = flag.String("user", "root", "MySQL user") +var testPassword = flag.String("pass", "", "MySQL password") +var testDB = flag.String("db", "test", "MySQL test database") func TestDriver(t *testing.T) { TestingT(t) @@ -23,7 +28,8 @@ type testDriverSuite struct { var _ = Suite(&testDriverSuite{}) func (s *testDriverSuite) SetUpSuite(c *C) { - dsn := fmt.Sprintf("root@%s:3306?test", *testHost) + addr := fmt.Sprintf("%s:%d", *testHost, *testPort) + dsn := fmt.Sprintf("%s:%s@%s?%s", *testUser, *testPassword, addr, *testDB) var err error s.db, err = sqlx.Open("mysql", dsn) diff --git a/vendor/github.com/siddontang/go-mysql/driver/driver.go b/vendor/github.com/siddontang/go-mysql/driver/driver.go index 6263929..e131548 100644 --- a/vendor/github.com/siddontang/go-mysql/driver/driver.go +++ b/vendor/github.com/siddontang/go-mysql/driver/driver.go @@ -20,7 +20,8 @@ type driver struct { // DSN user:password@addr[?db] func (d driver) Open(dsn string) (sqldriver.Conn, error) { - seps := strings.Split(dsn, "@") + lastIndex := strings.LastIndex(dsn, "@") + seps := []string{dsn[:lastIndex], dsn[lastIndex+1:]} if len(seps) != 2 { return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]") } diff --git a/vendor/github.com/siddontang/go-mysql/dump/dump.go b/vendor/github.com/siddontang/go-mysql/dump/dump.go index a6ff209..1f8384d 100644 --- a/vendor/github.com/siddontang/go-mysql/dump/dump.go +++ b/vendor/github.com/siddontang/go-mysql/dump/dump.go @@ -8,6 +8,8 @@ import ( "strings" "github.com/juju/errors" + "github.com/siddontang/go-log/log" + . "github.com/siddontang/go-mysql/mysql" ) // Unlick mysqldump, Dumper is designed for parsing and syning data easily. @@ -25,9 +27,16 @@ type Dumper struct { Databases []string + Where string + Charset string + IgnoreTables map[string][]string ErrOut io.Writer + + masterDataSkipped bool + maxAllowedPacket int + hexBlob bool } func NewDumper(executionPath string, addr string, user string, password string) (*Dumper, error) { @@ -47,17 +56,40 @@ func NewDumper(executionPath string, addr string, user string, password string) d.Password = password d.Tables = make([]string, 0, 16) d.Databases = make([]string, 0, 16) + d.Charset = DEFAULT_CHARSET d.IgnoreTables = make(map[string][]string) + d.masterDataSkipped = false d.ErrOut = os.Stderr return d, nil } +func (d *Dumper) SetCharset(charset string) { + d.Charset = charset +} + +func (d *Dumper) SetWhere(where string) { + d.Where = where +} + func (d *Dumper) SetErrOut(o io.Writer) { d.ErrOut = o } +// In some cloud MySQL, we have no privilege to use `--master-data`. +func (d *Dumper) SkipMasterData(v bool) { + d.masterDataSkipped = v +} + +func (d *Dumper) SetMaxAllowedPacket(i int) { + d.maxAllowedPacket = i +} + +func (d *Dumper) SetHexBlob(v bool) { + d.hexBlob = v +} + func (d *Dumper) AddDatabases(dbs ...string) { d.Databases = append(d.Databases, dbs...) } @@ -82,22 +114,35 @@ func (d *Dumper) Reset() { d.TableDB = "" d.IgnoreTables = make(map[string][]string) d.Databases = d.Databases[0:0] + d.Where = "" } func (d *Dumper) Dump(w io.Writer) error { args := make([]string, 0, 16) // Common args - seps := strings.Split(d.Addr, ":") - args = append(args, fmt.Sprintf("--host=%s", seps[0])) - if len(seps) > 1 { - args = append(args, fmt.Sprintf("--port=%s", seps[1])) + if strings.Contains(d.Addr, "/") { + args = append(args, fmt.Sprintf("--socket=%s", d.Addr)) + } else { + seps := strings.SplitN(d.Addr, ":", 2) + args = append(args, fmt.Sprintf("--host=%s", seps[0])) + if len(seps) > 1 { + args = append(args, fmt.Sprintf("--port=%s", seps[1])) + } } args = append(args, fmt.Sprintf("--user=%s", d.User)) args = append(args, fmt.Sprintf("--password=%s", d.Password)) - args = append(args, "--master-data") + if !d.masterDataSkipped { + args = append(args, "--master-data") + } + + if d.maxAllowedPacket > 0 { + // mysqldump param should be --max-allowed-packet=%dM not be --max_allowed_packet=%dM + args = append(args, fmt.Sprintf("--max-allowed-packet=%dM", d.maxAllowedPacket)) + } + args = append(args, "--single-transaction") args = append(args, "--skip-lock-tables") @@ -112,12 +157,25 @@ func (d *Dumper) Dump(w io.Writer) error { // Multi row is easy for us to parse the data args = append(args, "--skip-extended-insert") + if d.hexBlob { + // Use hex for the binary type + args = append(args, "--hex-blob") + } + for db, tables := range d.IgnoreTables { for _, table := range tables { args = append(args, fmt.Sprintf("--ignore-table=%s.%s", db, table)) } } + if len(d.Charset) != 0 { + args = append(args, fmt.Sprintf("--default-character-set=%s", d.Charset)) + } + + if len(d.Where) != 0 { + args = append(args, fmt.Sprintf("--where=%s", d.Where)) + } + if len(d.Tables) == 0 && len(d.Databases) == 0 { args = append(args, "--all-databases") } else if len(d.Tables) == 0 { @@ -133,6 +191,7 @@ func (d *Dumper) Dump(w io.Writer) error { w.Write([]byte(fmt.Sprintf("USE `%s`;\n", d.TableDB))) } + log.Infof("exec mysqldump with %v", args) cmd := exec.Command(d.ExecutionPath, args...) cmd.Stderr = d.ErrOut @@ -147,7 +206,7 @@ func (d *Dumper) DumpAndParse(h ParseHandler) error { done := make(chan error, 1) go func() { - err := Parse(r, h) + err := Parse(r, h, !d.masterDataSkipped) r.CloseWithError(err) done <- err }() diff --git a/vendor/github.com/siddontang/go-mysql/dump/dump_test.go b/vendor/github.com/siddontang/go-mysql/dump/dump_test.go index 39e430f..eed4c75 100644 --- a/vendor/github.com/siddontang/go-mysql/dump/dump_test.go +++ b/vendor/github.com/siddontang/go-mysql/dump/dump_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "os" + "strings" "testing" . "github.com/pingcap/check" @@ -38,6 +39,7 @@ func (s *schemaTestSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) c.Assert(s.d, NotNil) + s.d.SetCharset("utf8") s.d.SetErrOut(os.Stderr) _, err = s.conn.Execute("CREATE DATABASE IF NOT EXISTS test1") @@ -177,7 +179,7 @@ func (s *schemaTestSuite) TestParse(c *C) { err := s.d.Dump(&buf) c.Assert(err, IsNil) - err = Parse(&buf, new(testParseHandler)) + err = Parse(&buf, new(testParseHandler), true) c.Assert(err, IsNil) } @@ -196,3 +198,29 @@ func (s *parserTestSuite) TestParseValue(c *C) { values, err = parseValues(str) c.Assert(err, NotNil) } + +func (s *parserTestSuite) TestParseLine(c *C) { + lines := []struct { + line string + expected string + }{ + {line: "INSERT INTO `test` VALUES (1, 'first', 'hello mysql; 2', 'e1', 'a,b');", + expected: "1, 'first', 'hello mysql; 2', 'e1', 'a,b'"}, + {line: "INSERT INTO `test` VALUES (0x22270073646661736661736466, 'first', 'hello mysql; 2', 'e1', 'a,b');", + expected: "0x22270073646661736661736466, 'first', 'hello mysql; 2', 'e1', 'a,b'"}, + } + + f := func(c rune) bool { + return c == '\r' || c == '\n' + } + + for _, t := range lines { + l := strings.TrimRightFunc(t.line, f) + + m := valuesExp.FindAllStringSubmatch(l, -1) + + c.Assert(m, HasLen, 1) + c.Assert(m[0][1], Matches, "test") + c.Assert(m[0][2], Matches, t.expected) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/dump/parser.go b/vendor/github.com/siddontang/go-mysql/dump/parser.go index 69f65c7..ad40925 100644 --- a/vendor/github.com/siddontang/go-mysql/dump/parser.go +++ b/vendor/github.com/siddontang/go-mysql/dump/parser.go @@ -6,6 +6,7 @@ import ( "io" "regexp" "strconv" + "strings" "github.com/juju/errors" "github.com/siddontang/go-mysql/mysql" @@ -34,7 +35,7 @@ func init() { // Parse the dump data with Dumper generate. // It can not parse all the data formats with mysqldump outputs -func Parse(r io.Reader, h ParseHandler) error { +func Parse(r io.Reader, h ParseHandler, parseBinlogPos bool) error { rb := bufio.NewReaderSize(r, 1024*16) var db string @@ -48,9 +49,12 @@ func Parse(r io.Reader, h ParseHandler) error { break } - line = line[0 : len(line)-1] + // Ignore '\n' on Linux or '\r\n' on Windows + line = strings.TrimRightFunc(line, func(c rune) bool { + return c == '\r' || c == '\n' + }) - if !binlogParsed { + if parseBinlogPos && !binlogParsed { if m := binlogExp.FindAllStringSubmatch(line, -1); len(m) == 1 { name := m[0][1] pos, err := strconv.ParseUint(m[0][2], 10, 64) diff --git a/vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go b/vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go index 5c8bd64..2798241 100644 --- a/vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go +++ b/vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go @@ -46,12 +46,12 @@ func (h *MariadbGTIDHandler) FindBestSlaves(slaves []*Server) ([]*Server, error) if len(str) == 0 { seq = 0 } else { - g, err := ParseMariadbGTIDSet(str) + g, err := ParseMariadbGTID(str) if err != nil { return nil, errors.Trace(err) } - seq = g.(MariadbGTID).SequenceNumber + seq = g.SequenceNumber } ps[i] = seq @@ -118,7 +118,7 @@ func (h *MariadbGTIDHandler) WaitRelayLogDone(s *Server) error { fname, _ := r.GetStringByName(0, "Master_Log_File") pos, _ := r.GetIntByName(0, "Read_Master_Log_Pos") - return s.MasterPosWait(Position{fname, uint32(pos)}, 0) + return s.MasterPosWait(Position{Name: fname, Pos: uint32(pos)}, 0) } func (h *MariadbGTIDHandler) WaitCatchMaster(s *Server, m *Server) error { diff --git a/vendor/github.com/siddontang/go-mysql/failover/server.go b/vendor/github.com/siddontang/go-mysql/failover/server.go index ff87464..c02d6c8 100644 --- a/vendor/github.com/siddontang/go-mysql/failover/server.go +++ b/vendor/github.com/siddontang/go-mysql/failover/server.go @@ -152,7 +152,7 @@ func (s *Server) FetchSlaveReadPos() (Position, error) { fname, _ := r.GetStringByName(0, "Master_Log_File") pos, _ := r.GetIntByName(0, "Read_Master_Log_Pos") - return Position{fname, uint32(pos)}, nil + return Position{Name: fname, Pos: uint32(pos)}, nil } // Get current executed binlog filename and position from master @@ -165,7 +165,7 @@ func (s *Server) FetchSlaveExecutePos() (Position, error) { fname, _ := r.GetStringByName(0, "Relay_Master_Log_File") pos, _ := r.GetIntByName(0, "Exec_Master_Log_Pos") - return Position{fname, uint32(pos)}, nil + return Position{Name: fname, Pos: uint32(pos)}, nil } func (s *Server) MasterPosWait(pos Position, timeout int) error { diff --git a/vendor/github.com/siddontang/go-mysql/glide.lock b/vendor/github.com/siddontang/go-mysql/glide.lock deleted file mode 100644 index 7d71953..0000000 --- a/vendor/github.com/siddontang/go-mysql/glide.lock +++ /dev/null @@ -1,30 +0,0 @@ -hash: 1a3d05afef96cd7601a004e573b128db9051eecd1b5d0a3d69d3fa1ee1a3e3b8 -updated: 2016-09-03T12:30:00.028685232+08:00 -imports: -- name: github.com/BurntSushi/toml - version: 056c9bc7be7190eaa7715723883caffa5f8fa3e4 -- name: github.com/go-sql-driver/mysql - version: 3654d25ec346ee8ce71a68431025458d52a38ac0 -- name: github.com/jmoiron/sqlx - version: 54aec3fd91a2b2129ffaca0d652b8a9223ee2d9e - subpackages: - - reflectx -- name: github.com/juju/errors - version: 6f54ff6318409d31ff16261533ce2c8381a4fd5d -- name: github.com/ngaut/log - version: cec23d3e10b016363780d894a0eb732a12c06e02 -- name: github.com/pingcap/check - version: ce8a2f822ab1e245a4eefcef2996531c79c943f1 -- name: github.com/satori/go.uuid - version: 879c5887cd475cd7864858769793b2ceb0d44feb -- name: github.com/siddontang/go - version: 354e14e6c093c661abb29fd28403b3c19cff5514 - subpackages: - - hack - - ioutil2 - - sync2 -- name: golang.org/x/net - version: 6acef71eb69611914f7a30939ea9f6e194c78172 - subpackages: - - context -testImports: [] diff --git a/vendor/github.com/siddontang/go-mysql/glide.yaml b/vendor/github.com/siddontang/go-mysql/glide.yaml deleted file mode 100644 index 3561648..0000000 --- a/vendor/github.com/siddontang/go-mysql/glide.yaml +++ /dev/null @@ -1,26 +0,0 @@ -package: github.com/siddontang/go-mysql -import: -- package: github.com/BurntSushi/toml - version: 056c9bc7be7190eaa7715723883caffa5f8fa3e4 -- package: github.com/go-sql-driver/mysql - version: 3654d25ec346ee8ce71a68431025458d52a38ac0 -- package: github.com/jmoiron/sqlx - version: 54aec3fd91a2b2129ffaca0d652b8a9223ee2d9e - subpackages: - - reflectx -- package: github.com/juju/errors - version: 6f54ff6318409d31ff16261533ce2c8381a4fd5d -- package: github.com/ngaut/log - version: cec23d3e10b016363780d894a0eb732a12c06e02 -- package: github.com/pingcap/check - version: ce8a2f822ab1e245a4eefcef2996531c79c943f1 -- package: github.com/satori/go.uuid - version: ^1.1.0 -- package: github.com/siddontang/go - version: 354e14e6c093c661abb29fd28403b3c19cff5514 - subpackages: - - hack - - ioutil2 - - sync2 -- package: golang.org/x/net - version: 6acef71eb69611914f7a30939ea9f6e194c78172 diff --git a/vendor/github.com/siddontang/go-mysql/mysql/const.go b/vendor/github.com/siddontang/go-mysql/mysql/const.go index 2f4ab63..256d163 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/const.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/const.go @@ -6,16 +6,22 @@ const ( TimeFormat string = "2006-01-02 15:04:05" ) -var ( - // maybe you can change for your specified name - ServerVersion string = "go-mysql-0.1" -) - const ( OK_HEADER byte = 0x00 + MORE_DATE_HEADER byte = 0x01 ERR_HEADER byte = 0xff EOF_HEADER byte = 0xfe LocalInFile_HEADER byte = 0xfb + + CACHE_SHA2_FAST_AUTH byte = 0x03 + CACHE_SHA2_FULL_AUTH byte = 0x04 +) + +const ( + AUTH_MYSQL_OLD_PASSWORD = "mysql_old_password" + AUTH_NATIVE_PASSWORD = "mysql_native_password" + AUTH_CACHING_SHA2_PASSWORD = "caching_sha2_password" + AUTH_SHA256_PASSWORD = "sha256_password" ) const ( @@ -151,7 +157,6 @@ const ( ) const ( - AUTH_NAME = "mysql_native_password" DEFAULT_CHARSET = "utf8" DEFAULT_COLLATION_ID uint8 = 33 DEFAULT_COLLATION_NAME string = "utf8_general_ci" diff --git a/vendor/github.com/siddontang/go-mysql/mysql/error.go b/vendor/github.com/siddontang/go-mysql/mysql/error.go index 227e70e..876a408 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/error.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/error.go @@ -57,3 +57,10 @@ func NewError(errCode uint16, message string) *MyError { return e } + +func ErrorCode(errMsg string) (code int) { + var tmpStr string + // golang scanf doesn't support %*,so I used a temporary variable + fmt.Sscanf(errMsg, "%s%d", &tmpStr, &code) + return +} diff --git a/vendor/github.com/siddontang/go-mysql/mysql/field.go b/vendor/github.com/siddontang/go-mysql/mysql/field.go index c26f6a2..891f00b 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/field.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/field.go @@ -31,42 +31,42 @@ func (p FieldData) Parse() (f *Field, err error) { var n int pos := 0 //skip catelog, always def - n, err = SkipLengthEnodedString(p) + n, err = SkipLengthEncodedString(p) if err != nil { return } pos += n //schema - f.Schema, _, n, err = LengthEnodedString(p[pos:]) + f.Schema, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //table - f.Table, _, n, err = LengthEnodedString(p[pos:]) + f.Table, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //org_table - f.OrgTable, _, n, err = LengthEnodedString(p[pos:]) + f.OrgTable, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //name - f.Name, _, n, err = LengthEnodedString(p[pos:]) + f.Name, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //org_name - f.OrgName, _, n, err = LengthEnodedString(p[pos:]) + f.OrgName, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } diff --git a/vendor/github.com/siddontang/go-mysql/mysql/gtid.go b/vendor/github.com/siddontang/go-mysql/mysql/gtid.go index db0d638..cde9901 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/gtid.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/gtid.go @@ -11,6 +11,10 @@ type GTIDSet interface { Equal(o GTIDSet) bool Contain(o GTIDSet) bool + + Update(GTIDStr string) error + + Clone() GTIDSet } func ParseGTIDSet(flavor string, s string) (GTIDSet, error) { diff --git a/vendor/github.com/siddontang/go-mysql/mysql/mariadb_gtid.go b/vendor/github.com/siddontang/go-mysql/mysql/mariadb_gtid.go index ea89458..09fe7ac 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/mariadb_gtid.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/mariadb_gtid.go @@ -1,28 +1,32 @@ package mysql import ( + "bytes" "fmt" "strconv" "strings" "github.com/juju/errors" + "github.com/siddontang/go-log/log" + "github.com/siddontang/go/hack" ) +// MariadbGTID represent mariadb gtid, [domain ID]-[server-id]-[sequence] type MariadbGTID struct { DomainID uint32 ServerID uint32 SequenceNumber uint64 } -// We don't support multi source replication, so the mariadb gtid set may have only domain-server-sequence -func ParseMariadbGTIDSet(str string) (GTIDSet, error) { +// ParseMariadbGTID parses mariadb gtid, [domain ID]-[server-id]-[sequence] +func ParseMariadbGTID(str string) (*MariadbGTID, error) { if len(str) == 0 { - return MariadbGTID{0, 0, 0}, nil + return &MariadbGTID{0, 0, 0}, nil } seps := strings.Split(str, "-") - var gtid MariadbGTID + gtid := new(MariadbGTID) if len(seps) != 3 { return gtid, errors.Errorf("invalid Mariadb GTID %v, must domain-server-sequence", str) @@ -43,13 +47,13 @@ func ParseMariadbGTIDSet(str string) (GTIDSet, error) { return gtid, errors.Errorf("invalid MariaDB GTID Sequence number (%v): %v", seps[2], err) } - return MariadbGTID{ + return &MariadbGTID{ DomainID: uint32(domainID), ServerID: uint32(serverID), SequenceNumber: sequenceID}, nil } -func (gtid MariadbGTID) String() string { +func (gtid *MariadbGTID) String() string { if gtid.DomainID == 0 && gtid.ServerID == 0 && gtid.SequenceNumber == 0 { return "" } @@ -57,24 +61,172 @@ func (gtid MariadbGTID) String() string { return fmt.Sprintf("%d-%d-%d", gtid.DomainID, gtid.ServerID, gtid.SequenceNumber) } -func (gtid MariadbGTID) Encode() []byte { - return []byte(gtid.String()) -} - -func (gtid MariadbGTID) Equal(o GTIDSet) bool { - other, ok := o.(MariadbGTID) - if !ok { - return false - } - - return gtid == other -} - -func (gtid MariadbGTID) Contain(o GTIDSet) bool { - other, ok := o.(MariadbGTID) - if !ok { - return false - } - +// Contain return whether one mariadb gtid covers another mariadb gtid +func (gtid *MariadbGTID) Contain(other *MariadbGTID) bool { return gtid.DomainID == other.DomainID && gtid.SequenceNumber >= other.SequenceNumber } + +// Clone clones a mariadb gtid +func (gtid *MariadbGTID) Clone() *MariadbGTID { + o := new(MariadbGTID) + *o = *gtid + return o +} + +func (gtid *MariadbGTID) forward(newer *MariadbGTID) error { + if newer.DomainID != gtid.DomainID { + return errors.Errorf("%s is not same with doamin of %s", newer, gtid) + } + + /* + Here's a simplified example of binlog events. + Although I think one domain should have only one update at same time, we can't limit the user's usage. + we just output a warn log and let it go on + | mysqld-bin.000001 | 1453 | Gtid | 112 | 1495 | BEGIN GTID 0-112-6 | + | mysqld-bin.000001 | 1624 | Xid | 112 | 1655 | COMMIT xid=74 | + | mysqld-bin.000001 | 1655 | Gtid | 112 | 1697 | BEGIN GTID 0-112-7 | + | mysqld-bin.000001 | 1826 | Xid | 112 | 1857 | COMMIT xid=75 | + | mysqld-bin.000001 | 1857 | Gtid | 111 | 1899 | BEGIN GTID 0-111-5 | + | mysqld-bin.000001 | 1981 | Xid | 111 | 2012 | COMMIT xid=77 | + | mysqld-bin.000001 | 2012 | Gtid | 112 | 2054 | BEGIN GTID 0-112-8 | + | mysqld-bin.000001 | 2184 | Xid | 112 | 2215 | COMMIT xid=116 | + | mysqld-bin.000001 | 2215 | Gtid | 111 | 2257 | BEGIN GTID 0-111-6 | + */ + if newer.SequenceNumber <= gtid.SequenceNumber { + log.Warnf("out of order binlog appears with gtid %s vs current position gtid %s", newer, gtid) + } + + gtid.ServerID = newer.ServerID + gtid.SequenceNumber = newer.SequenceNumber + return nil +} + +// MariadbGTIDSet is a set of mariadb gtid +type MariadbGTIDSet struct { + Sets map[uint32]*MariadbGTID +} + +// ParseMariadbGTIDSet parses str into mariadb gtid sets +func ParseMariadbGTIDSet(str string) (GTIDSet, error) { + s := new(MariadbGTIDSet) + s.Sets = make(map[uint32]*MariadbGTID) + if str == "" { + return s, nil + } + + sp := strings.Split(str, ",") + + //todo, handle redundant same uuid + for i := 0; i < len(sp); i++ { + err := s.Update(sp[i]) + if err != nil { + return nil, errors.Trace(err) + } + } + return s, nil +} + +// AddSet adds mariadb gtid into mariadb gtid set +func (s *MariadbGTIDSet) AddSet(gtid *MariadbGTID) error { + if gtid == nil { + return nil + } + + o, ok := s.Sets[gtid.DomainID] + if ok { + err := o.forward(gtid) + if err != nil { + return errors.Trace(err) + } + } else { + s.Sets[gtid.DomainID] = gtid + } + + return nil +} + +// Update updates mariadb gtid set +func (s *MariadbGTIDSet) Update(GTIDStr string) error { + gtid, err := ParseMariadbGTID(GTIDStr) + if err != nil { + return err + } + + err = s.AddSet(gtid) + return errors.Trace(err) +} + +func (s *MariadbGTIDSet) String() string { + return hack.String(s.Encode()) +} + +// Encode encodes mariadb gtid set +func (s *MariadbGTIDSet) Encode() []byte { + var buf bytes.Buffer + sep := "" + for _, gtid := range s.Sets { + buf.WriteString(sep) + buf.WriteString(gtid.String()) + sep = "," + } + + return buf.Bytes() +} + +// Clone clones a mariadb gtid set +func (s *MariadbGTIDSet) Clone() GTIDSet { + clone := &MariadbGTIDSet{ + Sets: make(map[uint32]*MariadbGTID), + } + for domainID, gtid := range s.Sets { + clone.Sets[domainID] = gtid.Clone() + } + + return clone +} + +// Equal returns true if two mariadb gtid set is same, otherwise return false +func (s *MariadbGTIDSet) Equal(o GTIDSet) bool { + other, ok := o.(*MariadbGTIDSet) + if !ok { + return false + } + + if len(other.Sets) != len(s.Sets) { + return false + } + + for domainID, gtid := range other.Sets { + o, ok := s.Sets[domainID] + if !ok { + return false + } + + if *gtid != *o { + return false + } + } + + return true +} + +// Contain return whether one mariadb gtid set covers another mariadb gtid set +func (s *MariadbGTIDSet) Contain(o GTIDSet) bool { + other, ok := o.(*MariadbGTIDSet) + if !ok { + return false + } + + for doaminID, gtid := range other.Sets { + o, ok := s.Sets[doaminID] + if !ok { + return false + } + + if !o.Contain(gtid) { + return false + } + } + + return true +} diff --git a/vendor/github.com/siddontang/go-mysql/mysql/mariadb_gtid_test.go b/vendor/github.com/siddontang/go-mysql/mysql/mariadb_gtid_test.go new file mode 100644 index 0000000..1455e26 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/mysql/mariadb_gtid_test.go @@ -0,0 +1,234 @@ +package mysql + +import ( + "github.com/pingcap/check" +) + +type mariaDBTestSuite struct { +} + +var _ = check.Suite(&mariaDBTestSuite{}) + +func (t *mariaDBTestSuite) SetUpSuite(c *check.C) { + +} + +func (t *mariaDBTestSuite) TearDownSuite(c *check.C) { + +} + +func (t *mariaDBTestSuite) TestParseMariaDBGTID(c *check.C) { + cases := []struct { + gtidStr string + hashError bool + }{ + {"0-1-1", false}, + {"", false}, + {"0-1-1-1", true}, + {"1", true}, + {"0-1-seq", true}, + } + + for _, cs := range cases { + gtid, err := ParseMariadbGTID(cs.gtidStr) + if cs.hashError { + c.Assert(err, check.NotNil) + } else { + c.Assert(err, check.IsNil) + c.Assert(gtid.String(), check.Equals, cs.gtidStr) + } + } +} + +func (t *mariaDBTestSuite) TestMariaDBGTIDConatin(c *check.C) { + cases := []struct { + originGTIDStr, otherGTIDStr string + contain bool + }{ + {"0-1-1", "0-1-2", false}, + {"0-1-1", "", true}, + {"2-1-1", "1-1-1", false}, + {"1-2-1", "1-1-1", true}, + {"1-2-2", "1-1-1", true}, + } + + for _, cs := range cases { + originGTID, err := ParseMariadbGTID(cs.originGTIDStr) + c.Assert(err, check.IsNil) + otherGTID, err := ParseMariadbGTID(cs.otherGTIDStr) + c.Assert(err, check.IsNil) + + c.Assert(originGTID.Contain(otherGTID), check.Equals, cs.contain) + } +} + +func (t *mariaDBTestSuite) TestMariaDBGTIDClone(c *check.C) { + gtid, err := ParseMariadbGTID("1-1-1") + c.Assert(err, check.IsNil) + + clone := gtid.Clone() + c.Assert(gtid, check.DeepEquals, clone) +} + +func (t *mariaDBTestSuite) TestMariaDBForward(c *check.C) { + cases := []struct { + currentGTIDStr, newerGTIDStr string + hashError bool + }{ + {"0-1-1", "0-1-2", false}, + {"0-1-1", "", false}, + {"2-1-1", "1-1-1", true}, + {"1-2-1", "1-1-1", false}, + {"1-2-2", "1-1-1", false}, + } + + for _, cs := range cases { + currentGTID, err := ParseMariadbGTID(cs.currentGTIDStr) + c.Assert(err, check.IsNil) + newerGTID, err := ParseMariadbGTID(cs.newerGTIDStr) + c.Assert(err, check.IsNil) + + err = currentGTID.forward(newerGTID) + if cs.hashError { + c.Assert(err, check.NotNil) + c.Assert(currentGTID.String(), check.Equals, cs.currentGTIDStr) + } else { + c.Assert(err, check.IsNil) + c.Assert(currentGTID.String(), check.Equals, cs.newerGTIDStr) + } + } +} + +func (t *mariaDBTestSuite) TestParseMariaDBGTIDSet(c *check.C) { + cases := []struct { + gtidStr string + subGTIDs map[uint32]string //domain ID => gtid string + expectedStr []string // test String() + hasError bool + }{ + {"0-1-1", map[uint32]string{0: "0-1-1"}, []string{"0-1-1"}, false}, + {"", nil, []string{""}, false}, + {"0-1-1,1-2-3", map[uint32]string{0: "0-1-1", 1: "1-2-3"}, []string{"0-1-1,1-2-3", "1-2-3,0-1-1"}, false}, + {"0-1--1", nil, nil, true}, + } + + for _, cs := range cases { + gtidSet, err := ParseMariadbGTIDSet(cs.gtidStr) + if cs.hasError { + c.Assert(err, check.NotNil) + } else { + c.Assert(err, check.IsNil) + mariadbGTIDSet, ok := gtidSet.(*MariadbGTIDSet) + c.Assert(ok, check.IsTrue) + + // check sub gtid + c.Assert(mariadbGTIDSet.Sets, check.HasLen, len(cs.subGTIDs)) + for domainID, gtid := range mariadbGTIDSet.Sets { + c.Assert(mariadbGTIDSet.Sets, check.HasKey, domainID) + c.Assert(gtid.String(), check.Equals, cs.subGTIDs[domainID]) + } + + // check String() function + inExpectedResult := false + actualStr := mariadbGTIDSet.String() + for _, str := range cs.expectedStr { + if str == actualStr { + inExpectedResult = true + break + } + } + c.Assert(inExpectedResult, check.IsTrue) + } + } +} + +func (t *mariaDBTestSuite) TestMariaDBGTIDSetUpdate(c *check.C) { + cases := []struct { + isNilGTID bool + gtidStr string + subGTIDs map[uint32]string + }{ + {true, "", map[uint32]string{1: "1-1-1", 2: "2-2-2"}}, + {false, "1-2-2", map[uint32]string{1: "1-2-2", 2: "2-2-2"}}, + {false, "1-2-1", map[uint32]string{1: "1-2-1", 2: "2-2-2"}}, + {false, "3-2-1", map[uint32]string{1: "1-1-1", 2: "2-2-2", 3: "3-2-1"}}, + } + + for _, cs := range cases { + gtidSet, err := ParseMariadbGTIDSet("1-1-1,2-2-2") + c.Assert(err, check.IsNil) + mariadbGTIDSet, ok := gtidSet.(*MariadbGTIDSet) + c.Assert(ok, check.IsTrue) + + if cs.isNilGTID { + c.Assert(mariadbGTIDSet.AddSet(nil), check.IsNil) + } else { + err := gtidSet.Update(cs.gtidStr) + c.Assert(err, check.IsNil) + } + // check sub gtid + c.Assert(mariadbGTIDSet.Sets, check.HasLen, len(cs.subGTIDs)) + for domainID, gtid := range mariadbGTIDSet.Sets { + c.Assert(mariadbGTIDSet.Sets, check.HasKey, domainID) + c.Assert(gtid.String(), check.Equals, cs.subGTIDs[domainID]) + } + } +} + +func (t *mariaDBTestSuite) TestMariaDBGTIDSetEqual(c *check.C) { + cases := []struct { + originGTIDStr, otherGTIDStr string + equals bool + }{ + {"", "", true}, + {"1-1-1", "1-1-1,2-2-2", false}, + {"1-1-1,2-2-2", "1-1-1", false}, + {"1-1-1,2-2-2", "1-1-1,2-2-2", true}, + {"1-1-1,2-2-2", "1-1-1,2-2-3", false}, + } + + for _, cs := range cases { + originGTID, err := ParseMariadbGTIDSet(cs.originGTIDStr) + c.Assert(err, check.IsNil) + + otherGTID, err := ParseMariadbGTIDSet(cs.otherGTIDStr) + c.Assert(err, check.IsNil) + + c.Assert(originGTID.Equal(otherGTID), check.Equals, cs.equals) + } +} + +func (t *mariaDBTestSuite) TestMariaDBGTIDSetContain(c *check.C) { + cases := []struct { + originGTIDStr, otherGTIDStr string + contain bool + }{ + {"", "", true}, + {"1-1-1", "1-1-1,2-2-2", false}, + {"1-1-1,2-2-2", "1-1-1", true}, + {"1-1-1,2-2-2", "1-1-1,2-2-2", true}, + {"1-1-1,2-2-2", "1-1-1,2-2-1", true}, + {"1-1-1,2-2-2", "1-1-1,2-2-3", false}, + } + + for _, cs := range cases { + originGTIDSet, err := ParseMariadbGTIDSet(cs.originGTIDStr) + c.Assert(err, check.IsNil) + + otherGTIDSet, err := ParseMariadbGTIDSet(cs.otherGTIDStr) + c.Assert(err, check.IsNil) + + c.Assert(originGTIDSet.Contain(otherGTIDSet), check.Equals, cs.contain) + } +} + +func (t *mariaDBTestSuite) TestMariaDBGTIDSetClone(c *check.C) { + cases := []string{"", "1-1-1", "1-1-1,2-2-2"} + + for _, str := range cases { + gtidSet, err := ParseMariadbGTIDSet(str) + c.Assert(err, check.IsNil) + + c.Assert(gtidSet.Clone(), check.DeepEquals, gtidSet) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/mysql/mysql_gtid.go b/vendor/github.com/siddontang/go-mysql/mysql/mysql_gtid.go index c54ded0..a937cb8 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/mysql_gtid.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/mysql_gtid.go @@ -97,7 +97,11 @@ func (s IntervalSlice) Normalize() IntervalSlice { n = append(n, s[i]) continue } else { - n[len(n)-1] = Interval{last.Start, s[i].Stop} + stop := s[i].Stop + if last.Stop > stop { + stop = last.Stop + } + n[len(n)-1] = Interval{last.Start, stop} } } @@ -285,17 +289,28 @@ func (s *UUIDSet) Decode(data []byte) error { return err } +func (s *UUIDSet) Clone() *UUIDSet { + clone := new(UUIDSet) + + clone.SID, _ = uuid.FromString(s.SID.String()) + clone.Intervals = s.Intervals.Normalize() + + return clone +} + type MysqlGTIDSet struct { Sets map[string]*UUIDSet } func ParseMysqlGTIDSet(str string) (GTIDSet, error) { s := new(MysqlGTIDSet) + s.Sets = make(map[string]*UUIDSet) + if str == "" { + return s, nil + } sp := strings.Split(str, ",") - s.Sets = make(map[string]*UUIDSet, len(sp)) - //todo, handle redundant same uuid for i := 0; i < len(sp); i++ { if set, err := ParseUUIDSet(sp[i]); err != nil { @@ -334,6 +349,9 @@ func DecodeMysqlGTIDSet(data []byte) (*MysqlGTIDSet, error) { } func (s *MysqlGTIDSet) AddSet(set *UUIDSet) { + if set == nil { + return + } sid := set.SID.String() o, ok := s.Sets[sid] if ok { @@ -343,6 +361,17 @@ func (s *MysqlGTIDSet) AddSet(set *UUIDSet) { } } +func (s *MysqlGTIDSet) Update(GTIDStr string) error { + uuidSet, err := ParseUUIDSet(GTIDStr) + if err != nil { + return err + } + + s.AddSet(uuidSet) + + return nil +} + func (s *MysqlGTIDSet) Contain(o GTIDSet) bool { sub, ok := o.(*MysqlGTIDSet) if !ok { @@ -407,3 +436,14 @@ func (s *MysqlGTIDSet) Encode() []byte { return buf.Bytes() } + +func (gtid *MysqlGTIDSet) Clone() GTIDSet { + clone := &MysqlGTIDSet{ + Sets: make(map[string]*UUIDSet), + } + for sid, uuidSet := range gtid.Sets { + clone.Sets[sid] = uuidSet.Clone() + } + + return clone +} diff --git a/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go b/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go index 8fbaaab..df4b206 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go @@ -1,6 +1,7 @@ package mysql import ( + "strings" "testing" "github.com/pingcap/check" @@ -15,11 +16,11 @@ type mysqlTestSuite struct { var _ = check.Suite(&mysqlTestSuite{}) -func (s *mysqlTestSuite) SetUpSuite(c *check.C) { +func (t *mysqlTestSuite) SetUpSuite(c *check.C) { } -func (s *mysqlTestSuite) TearDownSuite(c *check.C) { +func (t *mysqlTestSuite) TearDownSuite(c *check.C) { } @@ -59,6 +60,12 @@ func (t *mysqlTestSuite) TestMysqlGTIDIntervalSlice(c *check.C) { n = i.Normalize() c.Assert(n, check.DeepEquals, IntervalSlice{Interval{1, 3}, Interval{4, 5}}) + i = IntervalSlice{Interval{1, 4}, Interval{2, 3}} + i.Sort() + c.Assert(i, check.DeepEquals, IntervalSlice{Interval{1, 4}, Interval{2, 3}}) + n = i.Normalize() + c.Assert(n, check.DeepEquals, IntervalSlice{Interval{1, 4}}) + n1 := IntervalSlice{Interval{1, 3}, Interval{4, 5}} n2 := IntervalSlice{Interval{1, 2}} @@ -91,6 +98,15 @@ func (t *mysqlTestSuite) TestMysqlGTIDCodec(c *check.C) { c.Assert(gs, check.DeepEquals, o) } +func (t *mysqlTestSuite) TestMysqlUpdate(c *check.C) { + g1, err := ParseMysqlGTIDSet("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-57") + c.Assert(err, check.IsNil) + + g1.Update("3E11FA47-71CA-11E1-9E33-C80AA9429562:21-58") + + c.Assert(strings.ToUpper(g1.String()), check.Equals, "3E11FA47-71CA-11E1-9E33-C80AA9429562:21-58") +} + func (t *mysqlTestSuite) TestMysqlGTIDContain(c *check.C) { g1, err := ParseMysqlGTIDSet("3E11FA47-71CA-11E1-9E33-C80AA9429562:23") c.Assert(err, check.IsNil) @@ -151,3 +167,26 @@ func (t *mysqlTestSuite) TestMysqlParseBinaryUint64(c *check.C) { u64 := ParseBinaryUint64([]byte{1, 2, 3, 4, 5, 6, 7, 128}) c.Assert(u64, check.Equals, 128*uint64(72057594037927936)+7*uint64(281474976710656)+6*uint64(1099511627776)+5*uint64(4294967296)+4*16777216+3*65536+2*256+1) } + +func (t *mysqlTestSuite) TestErrorCode(c *check.C) { + tbls := []struct { + msg string + code int + }{ + {"ERROR 1094 (HY000): Unknown thread id: 1094", 1094}, + {"error string", 0}, + {"abcdefg", 0}, + {"123455 ks094", 0}, + {"ERROR 1046 (3D000): Unknown error 1046", 1046}, + } + for _, v := range tbls { + c.Assert(ErrorCode(v.msg), check.Equals, v.code) + } +} + +func (t *mysqlTestSuite) TestMysqlNullDecode(c *check.C) { + _, isNull, n := LengthEncodedInt([]byte{0xfb}) + + c.Assert(isNull, check.IsTrue) + c.Assert(n, check.Equals, 1) +} diff --git a/vendor/github.com/siddontang/go-mysql/mysql/resultset.go b/vendor/github.com/siddontang/go-mysql/mysql/resultset.go index a50d9f8..b01e1a5 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/resultset.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/resultset.go @@ -28,7 +28,7 @@ func (p RowData) ParseText(f []*Field) ([]interface{}, error) { var n int = 0 for i := range f { - v, isNull, n, err = LengthEnodedString(p[pos:]) + v, isNull, n, err = LengthEncodedString(p[pos:]) if err != nil { return nil, errors.Trace(err) } @@ -115,7 +115,8 @@ func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) { } else { data[i] = ParseBinaryInt24(p[pos : pos+3]) } - pos += 4 + //3 byte + pos += 3 continue case MYSQL_TYPE_LONG: @@ -150,7 +151,7 @@ func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) { MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY: - v, isNull, n, err = LengthEnodedString(p[pos:]) + v, isNull, n, err = LengthEncodedString(p[pos:]) pos += n if err != nil { return nil, errors.Trace(err) diff --git a/vendor/github.com/siddontang/go-mysql/mysql/resultset_helper.go b/vendor/github.com/siddontang/go-mysql/mysql/resultset_helper.go index 488d253..307684d 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/resultset_helper.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/resultset_helper.go @@ -38,6 +38,8 @@ func formatTextValue(value interface{}) ([]byte, error) { return v, nil case string: return hack.Slice(v), nil + case nil: + return nil, nil default: return nil, errors.Errorf("invalid type %T", value) } @@ -77,23 +79,40 @@ func formatBinaryValue(value interface{}) ([]byte, error) { return nil, errors.Errorf("invalid type %T", value) } } + +func fieldType(value interface{}) (typ uint8, err error) { + switch value.(type) { + case int8, int16, int32, int64, int: + typ = MYSQL_TYPE_LONGLONG + case uint8, uint16, uint32, uint64, uint: + typ = MYSQL_TYPE_LONGLONG + case float32, float64: + typ = MYSQL_TYPE_DOUBLE + case string, []byte: + typ = MYSQL_TYPE_VAR_STRING + case nil: + typ = MYSQL_TYPE_NULL + default: + err = errors.Errorf("unsupport type %T for resultset", value) + } + return +} + func formatField(field *Field, value interface{}) error { switch value.(type) { case int8, int16, int32, int64, int: field.Charset = 63 - field.Type = MYSQL_TYPE_LONGLONG field.Flag = BINARY_FLAG | NOT_NULL_FLAG case uint8, uint16, uint32, uint64, uint: field.Charset = 63 - field.Type = MYSQL_TYPE_LONGLONG field.Flag = BINARY_FLAG | NOT_NULL_FLAG | UNSIGNED_FLAG case float32, float64: field.Charset = 63 - field.Type = MYSQL_TYPE_DOUBLE field.Flag = BINARY_FLAG | NOT_NULL_FLAG case string, []byte: field.Charset = 33 - field.Type = MYSQL_TYPE_VAR_STRING + case nil: + field.Charset = 33 default: return errors.Errorf("unsupport type %T for resultset", value) } @@ -106,7 +125,13 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse r.Fields = make([]*Field, len(names)) var b []byte - var err error + + if len(values) == 0 { + for i, name := range names { + r.Fields[i] = &Field{Name: hack.Slice(name), Charset: 33, Type: MYSQL_TYPE_NULL} + } + return r, nil + } for i, vs := range values { if len(vs) != len(r.Fields) { @@ -115,13 +140,23 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse var row []byte for j, value := range vs { - if i == 0 { - field := &Field{} - r.Fields[j] = field - field.Name = hack.Slice(names[j]) - - if err = formatField(field, value); err != nil { - return nil, errors.Trace(err) + typ, err := fieldType(value) + if err != nil { + return nil, errors.Trace(err) + } + if r.Fields[j] == nil { + r.Fields[j] = &Field{Name: hack.Slice(names[j]), Type: typ} + formatField(r.Fields[j], value) + } else if typ != r.Fields[j].Type { + // we got another type in the same column. in general, we treat it as an error, except + // the case, when old value was null, and the new one isn't null, so we can update + // type info for fields. + oldIsNull, newIsNull := r.Fields[j].Type == MYSQL_TYPE_NULL, typ == MYSQL_TYPE_NULL + if oldIsNull && !newIsNull { // old is null, new isn't, update type info. + r.Fields[j].Type = typ + formatField(r.Fields[j], value) + } else if !oldIsNull && !newIsNull { // different non-null types, that's an error. + return nil, errors.Errorf("row types aren't consistent") } } b, err = formatTextValue(value) @@ -130,7 +165,12 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse return nil, errors.Trace(err) } - row = append(row, PutLengthEncodedString(b)...) + if b == nil { + // NULL value is encoded as 0xfb here (without additional info about length) + row = append(row, 0xfb) + } else { + row = append(row, PutLengthEncodedString(b)...) + } } r.RowDatas = append(r.RowDatas, row) @@ -145,7 +185,6 @@ func BuildSimpleBinaryResultset(names []string, values [][]interface{}) (*Result r.Fields = make([]*Field, len(names)) var b []byte - var err error bitmapLen := ((len(names) + 7 + 2) >> 3) @@ -161,8 +200,12 @@ func BuildSimpleBinaryResultset(names []string, values [][]interface{}) (*Result row = append(row, nullBitmap...) for j, value := range vs { + typ, err := fieldType(value) + if err != nil { + return nil, errors.Trace(err) + } if i == 0 { - field := &Field{} + field := &Field{Type: typ} r.Fields[j] = field field.Name = hack.Slice(names[j]) diff --git a/vendor/github.com/siddontang/go-mysql/mysql/util.go b/vendor/github.com/siddontang/go-mysql/mysql/util.go index 7fe41fa..757910e 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/util.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/util.go @@ -11,6 +11,8 @@ import ( "github.com/juju/errors" "github.com/siddontang/go/hack" + "crypto/sha256" + "crypto/rsa" ) func Pstack() string { @@ -48,6 +50,62 @@ func CalcPassword(scramble, password []byte) []byte { return scramble } +// Hash password using MySQL 8+ method (SHA256) +func CalcCachingSha2Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + + +func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1v := sha1.New() + return rsa.EncryptOAEP(sha1v, rand.Reader, pub, plain, nil) +} + +// encodes a uint64 value and appends it to the given bytes slice +func AppendLengthEncodedInteger(b []byte, n uint64) []byte { + switch { + case n <= 250: + return append(b, byte(n)) + + case n <= 0xffff: + return append(b, 0xfc, byte(n), byte(n>>8)) + + case n <= 0xffffff: + return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) + } + return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), + byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) +} + func RandomBuf(size int) ([]byte, error) { buf := make([]byte, size) @@ -84,39 +142,33 @@ func BFixedLengthInt(buf []byte) uint64 { } func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) { - switch b[0] { + if len(b) == 0 { + return 0, true, 1 + } + switch b[0] { // 251: NULL case 0xfb: - n = 1 - isNull = true - return + return 0, true, 1 - // 252: value of following 2 + // 252: value of following 2 case 0xfc: - num = uint64(b[1]) | uint64(b[2])<<8 - n = 3 - return + return uint64(b[1]) | uint64(b[2])<<8, false, 3 - // 253: value of following 3 + // 253: value of following 3 case 0xfd: - num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 - n = 4 - return + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 - // 254: value of following 8 + // 254: value of following 8 case 0xfe: - num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | - uint64(b[7])<<48 | uint64(b[8])<<56 - n = 9 - return + uint64(b[7])<<48 | uint64(b[8])<<56, + false, 9 } // 0-250: value of first byte - num = uint64(b[0]) - n = 1 - return + return uint64(b[0]), false, 1 } func PutLengthEncodedInt(n uint64) []byte { @@ -137,23 +189,26 @@ func PutLengthEncodedInt(n uint64) []byte { return nil } -func LengthEnodedString(b []byte) ([]byte, bool, int, error) { +// returns the string read as a bytes slice, whether the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice +func LengthEncodedString(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := LengthEncodedInt(b) if num < 1 { - return nil, isNull, n, nil + return b[n:n], isNull, n, nil } n += int(num) // Check data length if len(b) >= n { - return b[n-int(num) : n], false, n, nil + return b[n-int(num) : n : n], false, n, nil } return nil, false, n, io.EOF } -func SkipLengthEnodedString(b []byte) (int, error) { +func SkipLengthEncodedString(b []byte) (int, error) { // Get length num, _, n := LengthEncodedInt(b) if num < 1 { diff --git a/vendor/github.com/siddontang/go-mysql/packet/conn.go b/vendor/github.com/siddontang/go-mysql/packet/conn.go index 3772e1a..41b1bf1 100644 --- a/vendor/github.com/siddontang/go-mysql/packet/conn.go +++ b/vendor/github.com/siddontang/go-mysql/packet/conn.go @@ -1,11 +1,17 @@ package packet +import "C" import ( - "bufio" "bytes" "io" "net" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/pem" + "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" ) @@ -15,7 +21,9 @@ import ( */ type Conn struct { net.Conn - br *bufio.Reader + + // we removed the buffer reader because it will cause the SSLRequest to block (tls connection handshake won't be + // able to read the "Client Hello" data since it has been buffered into the buffer reader) Sequence uint8 } @@ -23,7 +31,6 @@ type Conn struct { func NewConn(conn net.Conn) *Conn { c := new(Conn) - c.br = bufio.NewReaderSize(conn, 4096) c.Conn = conn return c @@ -37,55 +44,20 @@ func (c *Conn) ReadPacket() ([]byte, error) { } else { return buf.Bytes(), nil } - - // header := []byte{0, 0, 0, 0} - - // if _, err := io.ReadFull(c.br, header); err != nil { - // return nil, ErrBadConn - // } - - // length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - // if length < 1 { - // return nil, fmt.Errorf("invalid payload length %d", length) - // } - - // sequence := uint8(header[3]) - - // if sequence != c.Sequence { - // return nil, fmt.Errorf("invalid sequence %d != %d", sequence, c.Sequence) - // } - - // c.Sequence++ - - // data := make([]byte, length) - // if _, err := io.ReadFull(c.br, data); err != nil { - // return nil, ErrBadConn - // } else { - // if length < MaxPayloadLen { - // return data, nil - // } - - // var buf []byte - // buf, err = c.ReadPacket() - // if err != nil { - // return nil, ErrBadConn - // } else { - // return append(data, buf...), nil - // } - // } } func (c *Conn) ReadPacketTo(w io.Writer) error { header := []byte{0, 0, 0, 0} - if _, err := io.ReadFull(c.br, header); err != nil { + if _, err := io.ReadFull(c.Conn, header); err != nil { return ErrBadConn } length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - if length < 1 { - return errors.Errorf("invalid payload length %d", length) - } + // bug fixed: caching_sha2_password will send 0-length payload (the unscrambled password) when the password is empty + //if length < 1 { + // return errors.Errorf("invalid payload length %d", length) + //} sequence := uint8(header[3]) @@ -95,7 +67,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { c.Sequence++ - if n, err := io.CopyN(w, c.br, int64(length)); err != nil { + if n, err := io.CopyN(w, c.Conn, int64(length)); err != nil { return ErrBadConn } else if n != int64(length) { return ErrBadConn @@ -150,6 +122,77 @@ func (c *Conn) WritePacket(data []byte) error { } } +// Client clear text authentication packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (c *Conn) WriteClearAuthPacket(password string) error { + // Calculate the packet length and add a tailing 0 + pktLen := len(password) + 1 + data := make([]byte, 4 + pktLen) + + // Add the clear password [null terminated string] + copy(data[4:], password) + data[4+pktLen-1] = 0x00 + + return c.WritePacket(data) +} + +// Caching sha2 authentication. Public key request and send encrypted password +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (c *Conn) WritePublicKeyAuthPacket(password string, cipher []byte) error { + // request public key + data := make([]byte, 4 + 1) + data[4] = 2 // cachingSha2PasswordRequestPublicKey + c.WritePacket(data) + + data, err := c.ReadPacket() + if err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(cipher) + plain[i] ^= cipher[j] + } + sha1v := sha1.New() + enc, _ := rsa.EncryptOAEP(sha1v, rand.Reader, pub.(*rsa.PublicKey), plain, nil) + data = make([]byte, 4 + len(enc)) + copy(data[4:], enc) + return c.WritePacket(data) +} + +func (c *Conn) WriteEncryptedPassword(password string, seed []byte, pub *rsa.PublicKey) error { + enc, err := EncryptPassword(password, seed, pub) + if err != nil { + return err + } + return c.WriteAuthSwitchPacket(enc, false) +} + +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (c *Conn) WriteAuthSwitchPacket(authData []byte, addNUL bool) error { + pktLen := 4 + len(authData) + if addNUL { + pktLen++ + } + data := make([]byte, pktLen) + + // Add the auth data [EOF] + copy(data[4:], authData) + if addNUL { + data[pktLen-1] = 0x00 + } + + return c.WritePacket(data) +} + func (c *Conn) ResetSequence() { c.Sequence = 0 } diff --git a/vendor/github.com/siddontang/go-mysql/replication/backup.go b/vendor/github.com/siddontang/go-mysql/replication/backup.go index 744c38c..24a25ae 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/backup.go +++ b/vendor/github.com/siddontang/go-mysql/replication/backup.go @@ -1,13 +1,12 @@ package replication import ( + "context" "io" "os" "path" "time" - "golang.org/x/net/context" - "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" ) @@ -41,7 +40,7 @@ func (b *BinlogSyncer) StartBackup(backupDir string, p Position, timeout time.Du }() for { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) e, err := s.GetEvent(ctx) cancel() diff --git a/vendor/github.com/siddontang/go-mysql/replication/backup_test.go b/vendor/github.com/siddontang/go-mysql/replication/backup_test.go new file mode 100644 index 0000000..5e39e7f --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/replication/backup_test.go @@ -0,0 +1,50 @@ +package replication + +import ( + "context" + "github.com/juju/errors" + . "github.com/pingcap/check" + "github.com/siddontang/go-mysql/mysql" + "os" + "sync" + "time" +) + +func (t *testSyncerSuite) TestStartBackupEndInGivenTime(c *C) { + t.setupTest(c, mysql.MySQLFlavor) + + t.testExecute(c, "RESET MASTER") + + var wg sync.WaitGroup + wg.Add(1) + defer wg.Wait() + + go func() { + defer wg.Done() + + t.testSync(c, nil) + + t.testExecute(c, "FLUSH LOGS") + + t.testSync(c, nil) + }() + + os.RemoveAll("./var") + timeout := 2 * time.Second + + done := make(chan bool) + + go func() { + err := t.b.StartBackup("./var", mysql.Position{Name: "", Pos: uint32(0)}, timeout) + c.Assert(err, IsNil) + done <- true + }() + failTimeout := 5 * timeout + ctx, _ := context.WithTimeout(context.Background(), failTimeout) + select { + case <-done: + return + case <-ctx.Done(): + c.Assert(errors.New("time out error"), IsNil) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go b/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go index e5b165c..c1e4057 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go +++ b/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go @@ -1,10 +1,10 @@ package replication import ( - "golang.org/x/net/context" - + "context" + "time" "github.com/juju/errors" - "github.com/ngaut/log" + "github.com/siddontang/go-log/log" ) var ( @@ -36,6 +36,36 @@ func (s *BinlogStreamer) GetEvent(ctx context.Context) (*BinlogEvent, error) { } } +// Get the binlog event with starttime, if current binlog event timestamp smaller than specify starttime +// return nil event +func (s *BinlogStreamer) GetEventWithStartTime(ctx context.Context,startTime time.Time) (*BinlogEvent, error) { + if s.err != nil { + return nil, ErrNeedSyncAgain + } + startUnix := startTime.Unix() + select { + case c := <-s.ch: + if int64(c.Header.Timestamp) >= startUnix { + return c, nil + } + return nil,nil + case s.err = <-s.ech: + return nil, s.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// DumpEvents dumps all left events +func (s *BinlogStreamer) DumpEvents() []*BinlogEvent { + count := len(s.ch) + events := make([]*BinlogEvent, 0, count) + for i := 0; i < count; i++ { + events = append(events, <-s.ch) + } + return events +} + func (s *BinlogStreamer) close() { s.closeWithError(ErrSyncClosed) } diff --git a/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go b/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go index 90bf691..552798b 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go +++ b/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go @@ -1,17 +1,19 @@ package replication import ( + "context" "crypto/tls" "encoding/binary" "fmt" + "net" "os" + "strings" "sync" "time" - "golang.org/x/net/context" - "github.com/juju/errors" - "github.com/ngaut/log" + "github.com/satori/go.uuid" + "github.com/siddontang/go-log/log" "github.com/siddontang/go-mysql/client" . "github.com/siddontang/go-mysql/mysql" ) @@ -40,21 +42,64 @@ type BinlogSyncerConfig struct { // If not set, use os.Hostname() instead. Localhost string + // Charset is for MySQL client character set + Charset string + // SemiSyncEnabled enables semi-sync or not. SemiSyncEnabled bool - // RawModeEanbled is for not parsing binlog event. - RawModeEanbled bool + // RawModeEnabled is for not parsing binlog event. + RawModeEnabled bool // If not nil, use the provided tls.Config to connect to the database using TLS/SSL. TLSConfig *tls.Config + + // Use replication.Time structure for timestamp and datetime. + // We will use Local location for timestamp and UTC location for datatime. + ParseTime bool + + // If ParseTime is false, convert TIMESTAMP into this specified timezone. If + // ParseTime is true, this option will have no effect and TIMESTAMP data will + // be parsed into the local timezone and a full time.Time struct will be + // returned. + // + // Note that MySQL TIMESTAMP columns are offset from the machine local + // timezone while DATETIME columns are offset from UTC. This is consistent + // with documented MySQL behaviour as it return TIMESTAMP in local timezone + // and DATETIME in UTC. + // + // Setting this to UTC effectively equalizes the TIMESTAMP and DATETIME time + // strings obtained from MySQL. + TimestampStringLocation *time.Location + + // Use decimal.Decimal structure for decimals. + UseDecimal bool + + // RecvBufferSize sets the size in bytes of the operating system's receive buffer associated with the connection. + RecvBufferSize int + + // master heartbeat period + HeartbeatPeriod time.Duration + + // read timeout + ReadTimeout time.Duration + + // maximum number of attempts to re-establish a broken connection + MaxReconnectAttempts int + + // Only works when MySQL/MariaDB variable binlog_checksum=CRC32. + // For MySQL, binlog_checksum was introduced since 5.6.2, but CRC32 was set as default value since 5.6.6 . + // https://dev.mysql.com/doc/refman/5.6/en/replication-options-binary-log.html#option_mysqld_binlog-checksum + // For MariaDB, binlog_checksum was introduced since MariaDB 5.3, but CRC32 was set as default value since MariaDB 10.2.1 . + // https://mariadb.com/kb/en/library/replication-and-binary-log-server-system-variables/#binlog_checksum + VerifyChecksum bool } // BinlogSyncer syncs binlog event from server. type BinlogSyncer struct { m sync.RWMutex - cfg *BinlogSyncerConfig + cfg BinlogSyncerConfig c *client.Conn @@ -64,22 +109,39 @@ type BinlogSyncer struct { nextPos Position + gset GTIDSet + running bool ctx context.Context cancel context.CancelFunc + + lastConnectionID uint32 + + retryCount int } // NewBinlogSyncer creates the BinlogSyncer with cfg. -func NewBinlogSyncer(cfg *BinlogSyncerConfig) *BinlogSyncer { +func NewBinlogSyncer(cfg BinlogSyncerConfig) *BinlogSyncer { + if cfg.ServerID == 0 { + log.Fatal("can't use 0 as the server ID") + } + + // Clear the Password to avoid outputing it in log. + pass := cfg.Password + cfg.Password = "" log.Infof("create BinlogSyncer with config %v", cfg) + cfg.Password = pass b := new(BinlogSyncer) b.cfg = cfg b.parser = NewBinlogParser() - b.parser.SetRawMode(b.cfg.RawModeEanbled) - + b.parser.SetRawMode(b.cfg.RawModeEnabled) + b.parser.SetParseTime(b.cfg.ParseTime) + b.parser.SetTimestampStringLocation(b.cfg.TimestampStringLocation) + b.parser.SetUseDecimal(b.cfg.UseDecimal) + b.parser.SetVerifyChecksum(b.cfg.VerifyChecksum) b.running = false b.ctx, b.cancel = context.WithCancel(context.Background()) @@ -131,15 +193,53 @@ func (b *BinlogSyncer) registerSlave() error { b.c.Close() } - log.Infof("register slave for master server %s:%d", b.cfg.Host, b.cfg.Port) + addr := "" + if strings.Contains(b.cfg.Host, "/") { + addr = b.cfg.Host + } else { + addr = fmt.Sprintf("%s:%d", b.cfg.Host, b.cfg.Port) + } + + log.Infof("register slave for master server %s", addr) var err error - b.c, err = client.Connect(fmt.Sprintf("%s:%d", b.cfg.Host, b.cfg.Port), b.cfg.User, b.cfg.Password, "", func(c *client.Conn) { - c.TLSConfig = b.cfg.TLSConfig + b.c, err = client.Connect(addr, b.cfg.User, b.cfg.Password, "", func(c *client.Conn) { + c.SetTLSConfig(b.cfg.TLSConfig) }) if err != nil { return errors.Trace(err) } + if len(b.cfg.Charset) != 0 { + b.c.SetCharset(b.cfg.Charset) + } + + //set read timeout + if b.cfg.ReadTimeout > 0 { + b.c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) + } + + if b.cfg.RecvBufferSize > 0 { + if tcp, ok := b.c.Conn.Conn.(*net.TCPConn); ok { + tcp.SetReadBuffer(b.cfg.RecvBufferSize) + } + } + + // kill last connection id + if b.lastConnectionID > 0 { + cmd := fmt.Sprintf("KILL %d", b.lastConnectionID) + if _, err := b.c.Execute(cmd); err != nil { + log.Errorf("kill connection %d error %v", b.lastConnectionID, err) + // Unknown thread id + if code := ErrorCode(err.Error()); code != ER_NO_SUCH_THREAD { + return errors.Trace(err) + } + } + log.Infof("kill last connection id %d", b.lastConnectionID) + } + + // save last last connection id for kill + b.lastConnectionID = b.c.GetConnectionID() + //for mysql 5.6+, binlog has a crc32 checksum //before mysql 5.6, this will not work, don't matter.:-) if r, err := b.c.Execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'"); err != nil { @@ -175,6 +275,14 @@ func (b *BinlogSyncer) registerSlave() error { } } + if b.cfg.HeartbeatPeriod > 0 { + _, err = b.c.Execute(fmt.Sprintf("SET @master_heartbeat_period=%d;", b.cfg.HeartbeatPeriod)) + if err != nil { + log.Errorf("failed to set @master_heartbeat_period=%d, err: %v", b.cfg.HeartbeatPeriod, err) + return errors.Trace(err) + } + } + if err = b.writeRegisterSlaveCommand(); err != nil { return errors.Trace(err) } @@ -186,7 +294,7 @@ func (b *BinlogSyncer) registerSlave() error { return nil } -func (b *BinlogSyncer) enalbeSemiSync() error { +func (b *BinlogSyncer) enableSemiSync() error { if !b.cfg.SemiSyncEnabled { return nil } @@ -219,7 +327,7 @@ func (b *BinlogSyncer) prepare() error { return errors.Trace(err) } - if err := b.enalbeSemiSync(); err != nil { + if err := b.enableSemiSync(); err != nil { return errors.Trace(err) } @@ -236,6 +344,11 @@ func (b *BinlogSyncer) startDumpStream() *BinlogStreamer { return s } +// GetNextPosition returns the next position of the syncer +func (b *BinlogSyncer) GetNextPosition() Position { + return b.nextPos +} + // StartSync starts syncing from the `pos` position. func (b *BinlogSyncer) StartSync(pos Position) (*BinlogStreamer, error) { log.Infof("begin to sync binlog from position %s", pos) @@ -256,7 +369,9 @@ func (b *BinlogSyncer) StartSync(pos Position) (*BinlogStreamer, error) { // StartSyncGTID starts syncing from the `gset` GTIDSet. func (b *BinlogSyncer) StartSyncGTID(gset GTIDSet) (*BinlogStreamer, error) { - log.Infof("begin to sync binlog from GTID %s", gset) + log.Infof("begin to sync binlog from GTID set %s", gset) + + b.gset = gset b.m.Lock() defer b.m.Unlock() @@ -270,11 +385,12 @@ func (b *BinlogSyncer) StartSyncGTID(gset GTIDSet) (*BinlogStreamer, error) { } var err error - if b.cfg.Flavor != MariaDBFlavor { + switch b.cfg.Flavor { + case MariaDBFlavor: + err = b.writeBinlogDumpMariadbGTIDCommand(gset) + default: // default use MySQL err = b.writeBinlogDumpMysqlGTIDCommand(gset) - } else { - err = b.writeBinlogDumpMariadbGTIDCommand(gset) } if err != nil { @@ -284,7 +400,7 @@ func (b *BinlogSyncer) StartSyncGTID(gset GTIDSet) (*BinlogStreamer, error) { return b.startDumpStream(), nil } -func (b *BinlogSyncer) writeBinglogDumpCommand(p Position) error { +func (b *BinlogSyncer) writeBinlogDumpCommand(p Position) error { b.c.ResetSequence() data := make([]byte, 4+1+4+2+4+len(p.Name)) @@ -308,7 +424,7 @@ func (b *BinlogSyncer) writeBinglogDumpCommand(p Position) error { } func (b *BinlogSyncer) writeBinlogDumpMysqlGTIDCommand(gset GTIDSet) error { - p := Position{"", 4} + p := Position{Name: "", Pos: 4} gtidData := gset.Encode() b.c.ResetSequence() @@ -364,7 +480,7 @@ func (b *BinlogSyncer) writeBinlogDumpMariadbGTIDCommand(gset GTIDSet) error { } // Since we use @slave_connect_state, the file and position here are ignored. - return b.writeBinglogDumpCommand(Position{"", 0}) + return b.writeBinlogDumpCommand(Position{Name: "", Pos: 0}) } // localHostname returns the hostname that register slave would register as. @@ -439,21 +555,25 @@ func (b *BinlogSyncer) replySemiSyncACK(p Position) error { return errors.Trace(err) } - _, err = b.c.ReadOKPacket() - if err != nil { - } - return errors.Trace(err) + return nil } func (b *BinlogSyncer) retrySync() error { b.m.Lock() defer b.m.Unlock() - log.Infof("begin to re-sync from %s", b.nextPos) - b.parser.Reset() - if err := b.prepareSyncPos(b.nextPos); err != nil { - return errors.Trace(err) + + if b.gset != nil { + log.Infof("begin to re-sync from %s", b.gset.String()) + if err := b.prepareSyncGTID(b.gset); err != nil { + return errors.Trace(err) + } + } else { + log.Infof("begin to re-sync from %s", b.nextPos) + if err := b.prepareSyncPos(b.nextPos); err != nil { + return errors.Trace(err) + } } return nil @@ -469,13 +589,34 @@ func (b *BinlogSyncer) prepareSyncPos(pos Position) error { return errors.Trace(err) } - if err := b.writeBinglogDumpCommand(pos); err != nil { + if err := b.writeBinlogDumpCommand(pos); err != nil { return errors.Trace(err) } return nil } +func (b *BinlogSyncer) prepareSyncGTID(gset GTIDSet) error { + var err error + + if err = b.prepare(); err != nil { + return errors.Trace(err) + } + + switch b.cfg.Flavor { + case MariaDBFlavor: + err = b.writeBinlogDumpMariadbGTIDCommand(gset) + default: + // default use MySQL + err = b.writeBinlogDumpMysqlGTIDCommand(gset) + } + + if err != nil { + return err + } + return nil +} + func (b *BinlogSyncer) onStream(s *BinlogStreamer) { defer func() { if e := recover(); e != nil { @@ -490,21 +631,27 @@ func (b *BinlogSyncer) onStream(s *BinlogStreamer) { log.Error(err) // we meet connection error, should re-connect again with - // last nextPos we got. - if len(b.nextPos.Name) == 0 { + // last nextPos or nextGTID we got. + if len(b.nextPos.Name) == 0 && b.gset == nil { // we can't get the correct position, close. s.closeWithError(err) return } - // TODO: add a max retry count. for { select { case <-b.ctx.Done(): s.close() return case <-time.After(time.Second): + b.retryCount++ if err = b.retrySync(); err != nil { + if b.cfg.MaxReconnectAttempts > 0 && b.retryCount >= b.cfg.MaxReconnectAttempts { + log.Errorf("retry sync err: %v, exceeded max retries (%d)", err, b.cfg.MaxReconnectAttempts) + s.closeWithError(err) + return + } + log.Errorf("retry sync err: %v, wait 1s and retry again", err) continue } @@ -517,6 +664,14 @@ func (b *BinlogSyncer) onStream(s *BinlogStreamer) { continue } + //set read timeout + if b.cfg.ReadTimeout > 0 { + b.c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) + } + + // Reset retry count on successful packet receieve + b.retryCount = 0 + switch data[0] { case OK_HEADER: if err = b.parseEvent(s, data); err != nil { @@ -552,7 +707,7 @@ func (b *BinlogSyncer) parseEvent(s *BinlogStreamer, data []byte) error { data = data[2:] } - e, err := b.parser.parse(data) + e, err := b.parser.Parse(data) if err != nil { return errors.Trace(err) } @@ -561,11 +716,33 @@ func (b *BinlogSyncer) parseEvent(s *BinlogStreamer, data []byte) error { // Some events like FormatDescriptionEvent return 0, ignore. b.nextPos.Pos = e.Header.LogPos } - - if re, ok := e.Event.(*RotateEvent); ok { - b.nextPos.Name = string(re.NextLogName) - b.nextPos.Pos = uint32(re.Position) + switch event := e.Event.(type) { + case *RotateEvent: + b.nextPos.Name = string(event.NextLogName) + b.nextPos.Pos = uint32(event.Position) log.Infof("rotate to %s", b.nextPos) + case *GTIDEvent: + if b.gset == nil { + break + } + u, _ := uuid.FromBytes(event.SID) + err := b.gset.Update(fmt.Sprintf("%s:%d", u.String(), event.GNO)) + if err != nil { + return errors.Trace(err) + } + case *MariadbGTIDEvent: + if b.gset == nil { + break + } + GTID := event.GTID + err := b.gset.Update(fmt.Sprintf("%d-%d-%d", GTID.DomainID, GTID.ServerID, GTID.SequenceNumber)) + if err != nil { + return errors.Trace(err) + } + case *XIDEvent: + event.GSet = b.getGtidSet() + case *QueryEvent: + event.GSet = b.getGtidSet() } needStop := false @@ -588,3 +765,15 @@ func (b *BinlogSyncer) parseEvent(s *BinlogStreamer, data []byte) error { return nil } + +func (b *BinlogSyncer) getGtidSet() GTIDSet { + if b.gset == nil { + return nil + } + return b.gset.Clone() +} + +// LastConnectionID returns last connectionID. +func (b *BinlogSyncer) LastConnectionID() uint32 { + return b.lastConnectionID +} diff --git a/vendor/github.com/siddontang/go-mysql/replication/event.go b/vendor/github.com/siddontang/go-mysql/replication/event.go index 1a0d2c6..737b431 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/event.go +++ b/vendor/github.com/siddontang/go-mysql/replication/event.go @@ -2,7 +2,6 @@ package replication import ( "encoding/binary" - //"encoding/hex" "fmt" "io" "strconv" @@ -16,11 +15,15 @@ import ( ) const ( - EventHeaderSize = 19 + EventHeaderSize = 19 + SidLength = 16 + LogicalTimestampTypeCode = 2 + PartLogicalTimestampLength = 8 + BinlogChecksumLength = 4 ) type BinlogEvent struct { - // raw binlog data, including crc32 checksum if exists + // raw binlog data which contains all data, including binlog header and event body, and including crc32 checksum if exists RawData []byte Header *EventHeader @@ -50,7 +53,7 @@ type EventError struct { } func (e *EventError) Error() string { - return e.Err + return fmt.Sprintf("Header %#v, Data %q, Err: %v", e.Header, e.Data, e.Err) } type EventHeader struct { @@ -216,6 +219,9 @@ func (e *RotateEvent) Dump(w io.Writer) { type XIDEvent struct { XID uint64 + + // in fact XIDEvent dosen't have the GTIDSet information, just for beneficial to use + GSet GTIDSet } func (e *XIDEvent) Decode(data []byte) error { @@ -225,6 +231,9 @@ func (e *XIDEvent) Decode(data []byte) error { func (e *XIDEvent) Dump(w io.Writer) { fmt.Fprintf(w, "XID: %d\n", e.XID) + if e.GSet != nil { + fmt.Fprintf(w, "GTIDSet: %s\n", e.GSet.String()) + } fmt.Fprintln(w) } @@ -235,6 +244,9 @@ type QueryEvent struct { StatusVars []byte Schema []byte Query []byte + + // in fact QueryEvent dosen't have the GTIDSet information, just for beneficial to use + GSet GTIDSet } func (e *QueryEvent) Decode(data []byte) error { @@ -275,21 +287,36 @@ func (e *QueryEvent) Dump(w io.Writer) { //fmt.Fprintf(w, "Status vars: \n%s", hex.Dump(e.StatusVars)) fmt.Fprintf(w, "Schema: %s\n", e.Schema) fmt.Fprintf(w, "Query: %s\n", e.Query) + if e.GSet != nil { + fmt.Fprintf(w, "GTIDSet: %s\n", e.GSet.String()) + } fmt.Fprintln(w) } type GTIDEvent struct { - CommitFlag uint8 - SID []byte - GNO int64 + CommitFlag uint8 + SID []byte + GNO int64 + LastCommitted int64 + SequenceNumber int64 } func (e *GTIDEvent) Decode(data []byte) error { - e.CommitFlag = uint8(data[0]) - - e.SID = data[1:17] - - e.GNO = int64(binary.LittleEndian.Uint64(data[17:])) + pos := 0 + e.CommitFlag = uint8(data[pos]) + pos++ + e.SID = data[pos : pos+SidLength] + pos += SidLength + e.GNO = int64(binary.LittleEndian.Uint64(data[pos:])) + pos += 8 + if len(data) >= 42 { + if uint8(data[pos]) == LogicalTimestampTypeCode { + pos++ + e.LastCommitted = int64(binary.LittleEndian.Uint64(data[pos:])) + pos += PartLogicalTimestampLength + e.SequenceNumber = int64(binary.LittleEndian.Uint64(data[pos:])) + } + } return nil } @@ -297,6 +324,8 @@ func (e *GTIDEvent) Dump(w io.Writer) { fmt.Fprintf(w, "Commit flag: %d\n", e.CommitFlag) u, _ := uuid.FromBytes(e.SID) fmt.Fprintf(w, "GTID_NEXT: %s:%d\n", u.String(), e.GNO) + fmt.Fprintf(w, "LAST_COMMITTED: %d\n", e.LastCommitted) + fmt.Fprintf(w, "SEQUENCE_NUMBER: %d\n", e.SequenceNumber) fmt.Fprintln(w) } @@ -382,16 +411,16 @@ func (e *ExecuteLoadQueryEvent) Dump(w io.Writer) { // case MARIADB_ANNOTATE_ROWS_EVENT: // return "MariadbAnnotateRowsEvent" -type MariadbAnnotaeRowsEvent struct { +type MariadbAnnotateRowsEvent struct { Query []byte } -func (e *MariadbAnnotaeRowsEvent) Decode(data []byte) error { +func (e *MariadbAnnotateRowsEvent) Decode(data []byte) error { e.Query = data return nil } -func (e *MariadbAnnotaeRowsEvent) Dump(w io.Writer) { +func (e *MariadbAnnotateRowsEvent) Dump(w io.Writer) { fmt.Fprintf(w, "Query: %s\n", e.Query) fmt.Fprintln(w) } @@ -424,7 +453,7 @@ func (e *MariadbGTIDEvent) Decode(data []byte) error { } func (e *MariadbGTIDEvent) Dump(w io.Writer) { - fmt.Fprintf(w, "GTID: %s\n", e.GTID) + fmt.Fprintf(w, "GTID: %v\n", e.GTID) fmt.Fprintln(w) } diff --git a/vendor/github.com/siddontang/go-mysql/replication/json_binary.go b/vendor/github.com/siddontang/go-mysql/replication/json_binary.go index 83b69d0..6529f01 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/json_binary.go +++ b/vendor/github.com/siddontang/go-mysql/replication/json_binary.go @@ -70,8 +70,13 @@ func jsonbGetValueEntrySize(isSmall bool) int { // decodeJsonBinary decodes the JSON binary encoding data and returns // the common JSON encoding data. -func decodeJsonBinary(data []byte) ([]byte, error) { - d := new(jsonBinaryDecoder) +func (e *RowsEvent) decodeJsonBinary(data []byte) ([]byte, error) { + // Sometimes, we can insert a NULL JSON even we set the JSON field as NOT NULL. + // If we meet this case, we can return an empty slice. + if len(data) == 0 { + return []byte{}, nil + } + d := jsonBinaryDecoder{useDecimal: e.useDecimal} if d.isDataShort(data, 1) { return nil, d.err @@ -86,7 +91,8 @@ func decodeJsonBinary(data []byte) ([]byte, error) { } type jsonBinaryDecoder struct { - err error + useDecimal bool + err error } func (d *jsonBinaryDecoder) decodeValue(tp byte, data []byte) interface{} { @@ -338,7 +344,7 @@ func (d *jsonBinaryDecoder) decodeString(data []byte) string { l, n := d.decodeVariableLength(data) - if d.isDataShort(data, int(l)+n) { + if d.isDataShort(data, l+n) { return "" } @@ -358,11 +364,11 @@ func (d *jsonBinaryDecoder) decodeOpaque(data []byte) interface{} { l, n := d.decodeVariableLength(data) - if d.isDataShort(data, int(l)+n) { + if d.isDataShort(data, l+n) { return nil } - data = data[n : int(l)+n] + data = data[n : l+n] switch tp { case MYSQL_TYPE_NEWDECIMAL: @@ -382,7 +388,7 @@ func (d *jsonBinaryDecoder) decodeDecimal(data []byte) interface{} { precision := int(data[0]) scale := int(data[1]) - v, _, err := decodeDecimal(data[2:], precision, scale) + v, _, err := decodeDecimal(data[2:], precision, scale, d.useDecimal) d.err = err return v @@ -459,11 +465,11 @@ func (d *jsonBinaryDecoder) decodeVariableLength(data []byte) (int, int) { length := uint64(0) for ; pos < maxCount; pos++ { v := data[pos] - length = (length << 7) + uint64(v&0x7F) + length |= uint64(v&0x7F) << uint(7*pos) if v&0x80 == 0 { if length > math.MaxUint32 { - d.err = errors.Errorf("variable length %d must <= %d", length, math.MaxUint32) + d.err = errors.Errorf("variable length %d must <= %d", length, int64(math.MaxUint32)) return 0, 0 } diff --git a/vendor/github.com/siddontang/go-mysql/replication/parser.go b/vendor/github.com/siddontang/go-mysql/replication/parser.go index cfc97e4..6fe1cc0 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/parser.go +++ b/vendor/github.com/siddontang/go-mysql/replication/parser.go @@ -2,13 +2,22 @@ package replication import ( "bytes" + "encoding/binary" "fmt" + "hash/crc32" "io" "os" + "sync/atomic" + "time" "github.com/juju/errors" ) +var ( + // ErrChecksumMismatch indicates binlog checksum mismatch. + ErrChecksumMismatch = errors.New("binlog checksum mismatch, data may be corrupted") +) + type BinlogParser struct { format *FormatDescriptionEvent @@ -16,6 +25,15 @@ type BinlogParser struct { // for rawMode, we only parse FormatDescriptionEvent and RotateEvent rawMode bool + + parseTime bool + timestampStringLocation *time.Location + + // used to start/stop processing + stopProcessing uint32 + + useDecimal bool + verifyChecksum bool } func NewBinlogParser() *BinlogParser { @@ -26,6 +44,14 @@ func NewBinlogParser() *BinlogParser { return p } +func (p *BinlogParser) Stop() { + atomic.StoreUint32(&p.stopProcessing, 1) +} + +func (p *BinlogParser) Resume() { + atomic.StoreUint32(&p.stopProcessing, 0) +} + func (p *BinlogParser) Reset() { p.format = nil } @@ -48,64 +74,102 @@ func (p *BinlogParser) ParseFile(name string, offset int64, onEvent OnEventFunc) if offset < 4 { offset = 4 + } else if offset > 4 { + // FORMAT_DESCRIPTION event should be read by default always (despite that fact passed offset may be higher than 4) + if _, err = f.Seek(4, os.SEEK_SET); err != nil { + return errors.Errorf("seek %s to %d error %v", name, offset, err) + } + + if err = p.parseFormatDescriptionEvent(f, onEvent); err != nil { + return errors.Annotatef(err, "parse FormatDescriptionEvent") + } } if _, err = f.Seek(offset, os.SEEK_SET); err != nil { return errors.Errorf("seek %s to %d error %v", name, offset, err) } - return p.parseReader(f, onEvent) + return p.ParseReader(f, onEvent) } -func (p *BinlogParser) parseReader(r io.Reader, onEvent OnEventFunc) error { - p.Reset() +func (p *BinlogParser) parseFormatDescriptionEvent(r io.Reader, onEvent OnEventFunc) error { + _, err := p.parseSingleEvent(r, onEvent) + return err +} +// ParseSingleEvent parses single binlog event and passes the event to onEvent function. +func (p *BinlogParser) ParseSingleEvent(r io.Reader, onEvent OnEventFunc) (bool, error) { + return p.parseSingleEvent(r, onEvent) +} + +func (p *BinlogParser) parseSingleEvent(r io.Reader, onEvent OnEventFunc) (bool, error) { var err error var n int64 + var buf bytes.Buffer + if n, err = io.CopyN(&buf, r, EventHeaderSize); err == io.EOF { + return true, nil + } else if err != nil { + return false, errors.Errorf("get event header err %v, need %d but got %d", err, EventHeaderSize, n) + } + + var h *EventHeader + h, err = p.parseHeader(buf.Bytes()) + if err != nil { + return false, errors.Trace(err) + } + + if h.EventSize <= uint32(EventHeaderSize) { + return false, errors.Errorf("invalid event header, event size is %d, too small", h.EventSize) + } + if n, err = io.CopyN(&buf, r, int64(h.EventSize-EventHeaderSize)); err != nil { + return false, errors.Errorf("get event err %v, need %d but got %d", err, h.EventSize, n) + } + if buf.Len() != int(h.EventSize) { + return false, errors.Errorf("invalid raw data size in event %s, need %d but got %d", h.EventType, h.EventSize, buf.Len()) + } + + rawData := buf.Bytes() + bodyLen := int(h.EventSize) - EventHeaderSize + body := rawData[EventHeaderSize:] + if len(body) != bodyLen { + return false, errors.Errorf("invalid body data size in event %s, need %d but got %d", h.EventType, bodyLen, len(body)) + } + + var e Event + e, err = p.parseEvent(h, body, rawData) + if err != nil { + if err == errMissingTableMapEvent { + return false, nil + } + return false, errors.Trace(err) + } + + if err = onEvent(&BinlogEvent{RawData: rawData, Header: h, Event: e}); err != nil { + return false, errors.Trace(err) + } + + return false, nil +} + +func (p *BinlogParser) ParseReader(r io.Reader, onEvent OnEventFunc) error { + for { - headBuf := make([]byte, EventHeaderSize) - - if _, err = io.ReadFull(r, headBuf); err == io.EOF { - return nil - } else if err != nil { - return errors.Trace(err) - } - - var h *EventHeader - h, err = p.parseHeader(headBuf) - if err != nil { - return errors.Trace(err) - } - - if h.EventSize <= uint32(EventHeaderSize) { - return errors.Errorf("invalid event header, event size is %d, too small", h.EventSize) - - } - - var buf bytes.Buffer - if n, err = io.CopyN(&buf, r, int64(h.EventSize)-int64(EventHeaderSize)); err != nil { - return errors.Errorf("get event body err %v, need %d - %d, but got %d", err, h.EventSize, EventHeaderSize, n) - } - - data := buf.Bytes() - rawData := data - - eventLen := int(h.EventSize) - EventHeaderSize - - if len(data) != eventLen { - return errors.Errorf("invalid data size %d in event %s, less event length %d", len(data), h.EventType, eventLen) - } - - var e Event - e, err = p.parseEvent(h, data) - if err != nil { + if atomic.LoadUint32(&p.stopProcessing) == 1 { break } - if err = onEvent(&BinlogEvent{rawData, h, e}); err != nil { + done, err := p.parseSingleEvent(r, onEvent) + if err != nil { + if err == errMissingTableMapEvent { + continue + } return errors.Trace(err) } + + if done { + break + } } return nil @@ -115,6 +179,22 @@ func (p *BinlogParser) SetRawMode(mode bool) { p.rawMode = mode } +func (p *BinlogParser) SetParseTime(parseTime bool) { + p.parseTime = parseTime +} + +func (p *BinlogParser) SetTimestampStringLocation(timestampStringLocation *time.Location) { + p.timestampStringLocation = timestampStringLocation +} + +func (p *BinlogParser) SetUseDecimal(useDecimal bool) { + p.useDecimal = useDecimal +} + +func (p *BinlogParser) SetVerifyChecksum(verify bool) { + p.verifyChecksum = verify +} + func (p *BinlogParser) parseHeader(data []byte) (*EventHeader, error) { h := new(EventHeader) err := h.Decode(data) @@ -125,7 +205,7 @@ func (p *BinlogParser) parseHeader(data []byte) (*EventHeader, error) { return h, nil } -func (p *BinlogParser) parseEvent(h *EventHeader, data []byte) (Event, error) { +func (p *BinlogParser) parseEvent(h *EventHeader, data []byte, rawData []byte) (Event, error) { var e Event if h.EventType == FORMAT_DESCRIPTION_EVENT { @@ -133,7 +213,11 @@ func (p *BinlogParser) parseEvent(h *EventHeader, data []byte) (Event, error) { e = p.format } else { if p.format != nil && p.format.ChecksumAlgorithm == BINLOG_CHECKSUM_ALG_CRC32 { - data = data[0 : len(data)-4] + err := p.verifyCrc32Checksum(rawData) + if err != nil { + return nil, err + } + data = data[0 : len(data)-BinlogChecksumLength] } if h.EventType == ROTATE_EVENT { @@ -166,12 +250,14 @@ func (p *BinlogParser) parseEvent(h *EventHeader, data []byte) (Event, error) { e = &RowsQueryEvent{} case GTID_EVENT: e = >IDEvent{} + case ANONYMOUS_GTID_EVENT: + e = >IDEvent{} case BEGIN_LOAD_QUERY_EVENT: e = &BeginLoadQueryEvent{} case EXECUTE_LOAD_QUERY_EVENT: e = &ExecuteLoadQueryEvent{} case MARIADB_ANNOTATE_ROWS_EVENT: - e = &MariadbAnnotaeRowsEvent{} + e = &MariadbAnnotateRowsEvent{} case MARIADB_BINLOG_CHECKPOINT_EVENT: e = &MariadbBinlogCheckPointEvent{} case MARIADB_GTID_LIST_EVENT: @@ -206,7 +292,13 @@ func (p *BinlogParser) parseEvent(h *EventHeader, data []byte) (Event, error) { return e, nil } -func (p *BinlogParser) parse(data []byte) (*BinlogEvent, error) { +// Given the bytes for a a binary log event: return the decoded event. +// With the exception of the FORMAT_DESCRIPTION_EVENT event type +// there must have previously been passed a FORMAT_DESCRIPTION_EVENT +// into the parser for this to work properly on any given event. +// Passing a new FORMAT_DESCRIPTION_EVENT into the parser will replace +// an existing one. +func (p *BinlogParser) Parse(data []byte) (*BinlogEvent, error) { rawData := data h, err := p.parseHeader(data) @@ -222,12 +314,32 @@ func (p *BinlogParser) parse(data []byte) (*BinlogEvent, error) { return nil, fmt.Errorf("invalid data size %d in event %s, less event length %d", len(data), h.EventType, eventLen) } - e, err := p.parseEvent(h, data) + e, err := p.parseEvent(h, data, rawData) if err != nil { return nil, err } - return &BinlogEvent{rawData, h, e}, nil + return &BinlogEvent{RawData: rawData, Header: h, Event: e}, nil +} + +func (p *BinlogParser) verifyCrc32Checksum(rawData []byte) error { + if !p.verifyChecksum { + return nil + } + + calculatedPart := rawData[0 : len(rawData)-BinlogChecksumLength] + expectedChecksum := rawData[len(rawData)-BinlogChecksumLength:] + + // mysql use zlib's CRC32 implementation, which uses polynomial 0xedb88320UL. + // reference: https://github.com/madler/zlib/blob/master/crc32.c + // https://github.com/madler/zlib/blob/master/doc/rfc1952.txt#L419 + checksum := crc32.ChecksumIEEE(calculatedPart) + computed := make([]byte, BinlogChecksumLength) + binary.LittleEndian.PutUint32(computed, checksum) + if !bytes.Equal(expectedChecksum, computed) { + return ErrChecksumMismatch + } + return nil } func (p *BinlogParser) newRowsEvent(h *EventHeader) *RowsEvent { @@ -240,6 +352,9 @@ func (p *BinlogParser) newRowsEvent(h *EventHeader) *RowsEvent { e.needBitmap2 = false e.tables = p.tables + e.parseTime = p.parseTime + e.timestampStringLocation = p.timestampStringLocation + e.useDecimal = p.useDecimal switch h.EventType { case WRITE_ROWS_EVENTv0: diff --git a/vendor/github.com/siddontang/go-mysql/replication/parser_test.go b/vendor/github.com/siddontang/go-mysql/replication/parser_test.go index 52f5f1f..d4efc98 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/parser_test.go +++ b/vendor/github.com/siddontang/go-mysql/replication/parser_test.go @@ -29,7 +29,7 @@ func (t *testSyncerSuite) TestIndexOutOfRange(c *C) { 0x3065f: &TableMapEvent{tableIDSize: 6, TableID: 0x3065f, Flags: 0x1, Schema: []uint8{0x73, 0x65, 0x69, 0x75, 0x6d, 0x61, 0x73, 0x74, 0x65, 0x72}, Table: []uint8{0x63, 0x6f, 0x6e, 0x73, 0x5f, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x73, 0x70, 0x65, 0x61, 0x6b, 0x6f, 0x75, 0x74, 0x5f, 0x6c, 0x65, 0x74, 0x74, 0x65, 0x72}, ColumnCount: 0xd, ColumnType: []uint8{0x3, 0x3, 0x3, 0x3, 0x1, 0x12, 0xf, 0xf, 0x12, 0xf, 0xf, 0x3, 0xf}, ColumnMeta: []uint16{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x180, 0x180, 0x0, 0x180, 0x180, 0x0, 0x2fd}, NullBitmap: []uint8{0xe0, 0x17}}, } - _, err := parser.parse([]byte{ + _, err := parser.Parse([]byte{ /* 0x00, */ 0xc1, 0x86, 0x8e, 0x55, 0x1e, 0xa5, 0x14, 0x80, 0xa, 0x55, 0x0, 0x0, 0x0, 0x7, 0xc, 0xbf, 0xe, 0x0, 0x0, 0x5f, 0x6, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x2, 0x0, 0xd, 0xff, 0x0, 0x0, 0x19, 0x63, 0x7, 0x0, 0xca, 0x61, 0x5, 0x0, 0x5e, 0xf7, 0xc, 0x0, 0xf5, 0x7, diff --git a/vendor/github.com/siddontang/go-mysql/replication/replication_test.go b/vendor/github.com/siddontang/go-mysql/replication/replication_test.go index 8af37be..50fd1ee 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/replication_test.go +++ b/vendor/github.com/siddontang/go-mysql/replication/replication_test.go @@ -1,15 +1,15 @@ package replication import ( + "context" "flag" "fmt" "os" + "path" "sync" "testing" "time" - "golang.org/x/net/context" - . "github.com/pingcap/check" uuid "github.com/satori/go.uuid" "github.com/siddontang/go-mysql/client" @@ -230,6 +230,28 @@ func (t *testSyncerSuite) testSync(c *C, s *BinlogStreamer) { } } + str = `DROP TABLE IF EXISTS test_parse_time` + t.testExecute(c, str) + + // Must allow zero time. + t.testExecute(c, `SET sql_mode=''`) + str = `CREATE TABLE test_parse_time ( + a1 DATETIME, + a2 DATETIME(3), + a3 DATETIME(6), + b1 TIMESTAMP, + b2 TIMESTAMP(3) , + b3 TIMESTAMP(6))` + t.testExecute(c, str) + + t.testExecute(c, `INSERT INTO test_parse_time VALUES + ("2014-09-08 17:51:04.123456", "2014-09-08 17:51:04.123456", "2014-09-08 17:51:04.123456", + "2014-09-08 17:51:04.123456","2014-09-08 17:51:04.123456","2014-09-08 17:51:04.123456"), + ("0000-00-00 00:00:00.000000", "0000-00-00 00:00:00.000000", "0000-00-00 00:00:00.000000", + "0000-00-00 00:00:00.000000", "0000-00-00 00:00:00.000000", "0000-00-00 00:00:00.000000"), + ("2014-09-08 17:51:04.000456", "2014-09-08 17:51:04.000456", "2014-09-08 17:51:04.000456", + "2014-09-08 17:51:04.000456","2014-09-08 17:51:04.000456","2014-09-08 17:51:04.000456")`) + t.wg.Wait() } @@ -263,15 +285,16 @@ func (t *testSyncerSuite) setupTest(c *C, flavor string) { } cfg := BinlogSyncerConfig{ - ServerID: 100, - Flavor: flavor, - Host: *testHost, - Port: port, - User: "root", - Password: "", + ServerID: 100, + Flavor: flavor, + Host: *testHost, + Port: port, + User: "root", + Password: "", + UseDecimal: true, } - t.b = NewBinlogSyncer(&cfg) + t.b = NewBinlogSyncer(cfg) } func (t *testSyncerSuite) testPositionSync(c *C) { @@ -281,7 +304,7 @@ func (t *testSyncerSuite) testPositionSync(c *C) { binFile, _ := r.GetString(0, 0) binPos, _ := r.GetInt(0, 1) - s, err := t.b.StartSync(mysql.Position{binFile, uint32(binPos)}) + s, err := t.b.StartSync(mysql.Position{Name: binFile, Pos: uint32(binPos)}) c.Assert(err, IsNil) // Test re-sync. @@ -373,12 +396,15 @@ func (t *testSyncerSuite) TestMysqlBinlogCodec(c *C) { t.testSync(c, nil) }() - os.RemoveAll("./var") + binlogDir := "./var" - err := t.b.StartBackup("./var", mysql.Position{"", uint32(0)}, 2*time.Second) + os.RemoveAll(binlogDir) + + err := t.b.StartBackup(binlogDir, mysql.Position{Name: "", Pos: uint32(0)}, 2*time.Second) c.Assert(err, IsNil) p := NewBinlogParser() + p.SetVerifyChecksum(true) f := func(e *BinlogEvent) error { if *testOutputLogs { @@ -388,9 +414,15 @@ func (t *testSyncerSuite) TestMysqlBinlogCodec(c *C) { return nil } - err = p.ParseFile("./var/mysql.000001", 0, f) + dir, err := os.Open(binlogDir) + c.Assert(err, IsNil) + defer dir.Close() + + files, err := dir.Readdirnames(-1) c.Assert(err, IsNil) - err = p.ParseFile("./var/mysql.000002", 0, f) - c.Assert(err, IsNil) + for _, file := range files { + err = p.ParseFile(path.Join(binlogDir, file), 0, f) + c.Assert(err, IsNil) + } } diff --git a/vendor/github.com/siddontang/go-mysql/replication/row_event.go b/vendor/github.com/siddontang/go-mysql/replication/row_event.go index c30d9ae..9172f6e 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/row_event.go +++ b/vendor/github.com/siddontang/go-mysql/replication/row_event.go @@ -10,11 +10,14 @@ import ( "time" "github.com/juju/errors" - "github.com/ngaut/log" + "github.com/shopspring/decimal" + "github.com/siddontang/go-log/log" . "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go/hack" ) +var errMissingTableMapEvent = errors.New("invalid table id, no corresponding table map event") + type TableMapEvent struct { tableIDSize int @@ -68,7 +71,7 @@ func (e *TableMapEvent) Decode(data []byte) error { var err error var metaData []byte - if metaData, _, n, err = LengthEnodedString(data[pos:]); err != nil { + if metaData, _, n, err = LengthEncodedString(data[pos:]); err != nil { return errors.Trace(err) } @@ -78,11 +81,14 @@ func (e *TableMapEvent) Decode(data []byte) error { pos += n - if len(data[pos:]) != bitmapByteSize(int(e.ColumnCount)) { + nullBitmapSize := bitmapByteSize(int(e.ColumnCount)) + if len(data[pos:]) < nullBitmapSize { return io.EOF } - e.NullBitmap = data[pos:] + e.NullBitmap = data[pos : pos+nullBitmapSize] + + // TODO: handle optional field meta return nil } @@ -223,6 +229,10 @@ type RowsEvent struct { //rows: invalid: int64, float64, bool, []byte, string Rows [][]interface{} + + parseTime bool + timestampStringLocation *time.Location + useDecimal bool } func (e *RowsEvent) Decode(data []byte) error { @@ -257,7 +267,11 @@ func (e *RowsEvent) Decode(data []byte) error { var ok bool e.Table, ok = e.tables[e.TableID] if !ok { - return errors.Errorf("invalid table id %d, no correspond table map event", e.TableID) + if len(e.tables) > 0 { + return errors.Errorf("invalid table id %d, no corresponding table map event", e.TableID) + } else { + return errors.Annotatef(errMissingTableMapEvent, "table id %d", e.TableID) + } } var err error @@ -336,6 +350,21 @@ func (e *RowsEvent) decodeRows(data []byte, table *TableMapEvent, bitmap []byte) return pos, nil } +func (e *RowsEvent) parseFracTime(t interface{}) interface{} { + v, ok := t.(fracTime) + if !ok { + return t + } + + if !e.parseTime { + // Don't parse time, return string directly + return v.String() + } + + // return Golang time directly + return v.Time +} + // see mysql sql/log_event.cc log_event_print_value func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{}, n int, err error) { var length int = 0 @@ -378,7 +407,7 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{ case MYSQL_TYPE_NEWDECIMAL: prec := uint8(meta >> 8) scale := uint8(meta & 0xFF) - v, n, err = decodeDecimal(data, int(prec), int(scale)) + v, n, err = decodeDecimal(data, int(prec), int(scale), e.useDecimal) case MYSQL_TYPE_FLOAT: n = 4 v = ParseBinaryFloat32(data) @@ -396,10 +425,15 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{ t := binary.LittleEndian.Uint32(data) v = time.Unix(int64(t), 0) case MYSQL_TYPE_TIMESTAMP2: - v, n, err = decodeTimestamp2(data, meta) + v, n, err = decodeTimestamp2(data, meta, e.timestampStringLocation) + //v = e.parseFracTime(v) case MYSQL_TYPE_DATETIME: n = 8 i64 := binary.LittleEndian.Uint64(data) + + if i64 == 0 { // commented by Shlomi Noach. Yes I know about `git blame` + return "0000-00-00 00:00:00", n, nil + } d := i64 / 1000000 t := i64 % 1000000 v = time.Date(int(d/10000), @@ -412,6 +446,7 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{ time.UTC).Format(TimeFormat) case MYSQL_TYPE_DATETIME2: v, n, err = decodeDatetime2(data, meta) + v = e.parseFracTime(v) case MYSQL_TYPE_TIME: n = 3 i32 := uint32(FixedLengthInt(data[0:3])) @@ -445,7 +480,7 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{ v = int64(data[0]) n = 1 case 2: - v = int64(binary.BigEndian.Uint16(data)) + v = int64(binary.LittleEndian.Uint16(data)) n = 2 default: err = fmt.Errorf("Unknown ENUM packlen=%d", l) @@ -454,7 +489,7 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{ n = int(meta & 0xFF) nbits := n * 8 - v, err = decodeBit(data, nbits, n) + v, err = littleDecodeBit(data, nbits, n) case MYSQL_TYPE_BLOB: v, n, err = decodeBlob(data, meta) case MYSQL_TYPE_VARCHAR, @@ -464,10 +499,10 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{ case MYSQL_TYPE_STRING: v, n = decodeString(data, length) case MYSQL_TYPE_JSON: - // Refer https://github.com/shyiko/mysql-binlog-connector-java/blob/8f9132ee773317e00313204beeae8ddcaa43c1b4/src/main/java/com/github/shyiko/mysql/binlog/event/deserialization/AbstractRowsEventDataDeserializer.java#L344 - length = int(binary.LittleEndian.Uint32(data[0:])) + // Refer: https://github.com/shyiko/mysql-binlog-connector-java/blob/master/src/main/java/com/github/shyiko/mysql/binlog/event/deserialization/AbstractRowsEventDataDeserializer.java#L404 + length = int(FixedLengthInt(data[0:meta])) n = length + int(meta) - v, err = decodeJsonBinary(data[meta:n]) + v, err = e.decodeJsonBinary(data[meta:n]) case MYSQL_TYPE_GEOMETRY: // MySQL saves Geometry as Blob in binlog // Seem that the binary format is SRID (4 bytes) + WKB, outer can use @@ -511,7 +546,7 @@ func decodeDecimalDecompressValue(compIndx int, data []byte, mask uint8) (size i return } -func decodeDecimal(data []byte, precision int, decimals int) (float64, int, error) { +func decodeDecimal(data []byte, precision int, decimals int, useDecimal bool) (interface{}, int, error) { //see python mysql replication and https://github.com/jeremycole/mysql_binlog integral := (precision - decimals) uncompIntegral := int(integral / digitsPerInteger) @@ -564,6 +599,11 @@ func decodeDecimal(data []byte, precision int, decimals int) (float64, int, erro pos += size } + if useDecimal { + f, err := decimal.NewFromString(hack.String(res.Bytes())) + return f, pos, err + } + f, err := strconv.ParseFloat(hack.String(res.Bytes()), 64) return f, pos, err } @@ -600,7 +640,39 @@ func decodeBit(data []byte, nbits int, length int) (value int64, err error) { return } -func decodeTimestamp2(data []byte, dec uint16) (interface{}, int, error) { +func littleDecodeBit(data []byte, nbits int, length int) (value int64, err error) { + if nbits > 1 { + switch length { + case 1: + value = int64(data[0]) + case 2: + value = int64(binary.LittleEndian.Uint16(data)) + case 3: + value = int64(FixedLengthInt(data[0:3])) + case 4: + value = int64(binary.LittleEndian.Uint32(data)) + case 5: + value = int64(FixedLengthInt(data[0:5])) + case 6: + value = int64(FixedLengthInt(data[0:6])) + case 7: + value = int64(FixedLengthInt(data[0:7])) + case 8: + value = int64(binary.LittleEndian.Uint64(data)) + default: + err = fmt.Errorf("invalid bit length %d", length) + } + } else { + if length != 1 { + err = fmt.Errorf("invalid bit length %d", length) + } else { + value = int64(data[0]) + } + } + return +} + +func decodeTimestamp2(data []byte, dec uint16, timestampStringLocation *time.Location) (interface{}, int, error) { //get timestamp binary length n := int(4 + (dec+1)/2) sec := int64(binary.BigEndian.Uint32(data[0:4])) @@ -615,7 +687,7 @@ func decodeTimestamp2(data []byte, dec uint16) (interface{}, int, error) { } if sec == 0 { - return "0000-00-00 00:00:00", n, nil + return formatZeroTime(int(usec), int(dec)), n, nil } t := time.Unix(sec, usec*1000) @@ -641,7 +713,7 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { } if intPart == 0 { - return "0000-00-00 00:00:00", n, nil + return formatZeroTime(int(frac), int(dec)), n, nil } tmp := intPart<<24 + frac @@ -650,7 +722,7 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { tmp = -tmp } - var secPart int64 = tmp % (1 << 24) + // var secPart int64 = tmp % (1 << 24) ymdhms := tmp >> 24 ymd := ymdhms >> 17 @@ -665,10 +737,14 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { minute := int((hms >> 6) % (1 << 6)) hour := int((hms >> 12)) - if secPart != 0 { - return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%06d", year, month, day, hour, minute, second, secPart), n, nil // commented by Shlomi Noach. Yes I know about `git blame` + if frac != 0 { + return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%06d", year, month, day, hour, minute, second, frac), n, nil // commented by Shlomi Noach. Yes I know about `git blame` } return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second), n, nil // commented by Shlomi Noach. Yes I know about `git blame` + // return fracTime{ + // Time: time.Date(year, time.Month(month), day, hour, minute, second, int(frac*1000), time.UTC), + // Dec: int(dec), + // }, n, nil } const TIMEF_OFS int64 = 0x800000000000 diff --git a/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go b/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go index 2d63bb5..533a876 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go +++ b/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go @@ -2,8 +2,10 @@ package replication import ( "fmt" + "strconv" . "github.com/pingcap/check" + "github.com/shopspring/decimal" ) type testDecodeSuite struct{} @@ -17,10 +19,10 @@ type decodeDecimalChecker struct { func (_ *decodeDecimalChecker) Check(params []interface{}, names []string) (bool, string) { var test int val := struct { - Value float64 + Value decimal.Decimal Pos int Err error - EValue float64 + EValue decimal.Decimal EPos int EErr error }{} @@ -28,13 +30,13 @@ func (_ *decodeDecimalChecker) Check(params []interface{}, names []string) (bool for i, name := range names { switch name { case "obtainedValue": - val.Value, _ = params[i].(float64) + val.Value, _ = params[i].(decimal.Decimal) case "obtainedPos": val.Pos, _ = params[i].(int) case "obtainedErr": val.Err, _ = params[i].(error) case "expectedValue": - val.EValue, _ = params[i].(float64) + val.EValue, _ = params[i].(decimal.Decimal) case "expectedPos": val.EPos, _ = params[i].(int) case "expectedErr": @@ -50,7 +52,7 @@ func (_ *decodeDecimalChecker) Check(params []interface{}, names []string) (bool if val.Pos != val.EPos { return false, fmt.Sprintf(errorMsgFmt, "position", val.EPos, val.Pos) } - if val.Value != val.EValue { + if !val.Value.Equal(val.EValue) { return false, fmt.Sprintf(errorMsgFmt, "value", val.EValue, val.Value) } return true, "" @@ -66,7 +68,7 @@ func (_ *testDecodeSuite) TestDecodeDecimal(c *C) { Data []byte Precision int Decimals int - Expected float64 + Expected string ExpectedPos int ExpectedErr error }{ @@ -133,197 +135,202 @@ func (_ *testDecodeSuite) TestDecodeDecimal(c *C) { | 17 | -99.99 | -1948 | -1948.140 | -1948.14 | -1948.140 | -1948.14 | -9.99999999999999 | -1948.1400000000 | -1948.14000 | -1948.14000000000000000000 | -1948.1400000000000000000000000 | 13 | 2 | +----+--------+-------+-----------+-------------+-------------+----------------+-------------------+-----------------------+---------------------+---------------------------------+---------------------------------+------+-------+ */ - {[]byte{117, 200, 127, 255}, 4, 2, float64(-10.55), 2, nil}, - {[]byte{127, 255, 244, 127, 245}, 5, 0, float64(-11), 3, nil}, - {[]byte{127, 245, 253, 217, 127, 255}, 7, 3, float64(-10.550), 4, nil}, - {[]byte{127, 255, 255, 245, 200, 127, 255}, 10, 2, float64(-10.55), 5, nil}, - {[]byte{127, 255, 255, 245, 253, 217, 127, 255}, 10, 3, float64(-10.550), 6, nil}, - {[]byte{127, 255, 255, 255, 245, 200, 118, 196}, 13, 2, float64(-10.55), 6, nil}, - {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, float64(-9.99999999999999), 8, nil}, - {[]byte{127, 255, 255, 255, 245, 223, 55, 170, 127, 255, 127, 255}, 20, 10, float64(-10.5500000000), 10, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 245, 255, 41, 39, 127, 255}, 30, 5, float64(-10.55000), 15, nil}, - {[]byte{127, 255, 255, 255, 245, 223, 55, 170, 127, 255, 255, 255, 255, 255, 127, 255}, 30, 20, float64(-10.55000000000000000000), 14, nil}, - {[]byte{127, 255, 245, 223, 55, 170, 127, 255, 255, 255, 255, 255, 255, 255, 255, 4, 0}, 30, 25, float64(-10.5500000000000000000000000), 15, nil}, - {[]byte{128, 1, 128, 0}, 4, 2, float64(0.01), 2, nil}, - {[]byte{128, 0, 0, 128, 0}, 5, 0, float64(0), 3, nil}, - {[]byte{128, 0, 0, 12, 128, 0}, 7, 3, float64(0.012), 4, nil}, - {[]byte{128, 0, 0, 0, 1, 128, 0}, 10, 2, float64(0.01), 5, nil}, - {[]byte{128, 0, 0, 0, 0, 12, 128, 0}, 10, 3, float64(0.012), 6, nil}, - {[]byte{128, 0, 0, 0, 0, 1, 128, 0}, 13, 2, float64(0.01), 6, nil}, - {[]byte{128, 0, 188, 97, 78, 1, 96, 11, 128, 0}, 15, 14, float64(0.01234567890123), 8, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 188, 97, 78, 9, 128, 0}, 20, 10, float64(0.0123456789), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 211, 128, 0}, 30, 5, float64(0.01235), 15, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 188, 97, 78, 53, 183, 191, 135, 89, 128, 0}, 30, 20, float64(0.01234567890123456789), 14, nil}, - {[]byte{128, 0, 0, 0, 188, 97, 78, 53, 183, 191, 135, 0, 135, 253, 217, 30, 0}, 30, 25, float64(0.0123456789012345678912345), 15, nil}, - {[]byte{227, 99, 128, 48}, 4, 2, float64(99.99), 2, nil}, - {[]byte{128, 48, 57, 167, 15}, 5, 0, float64(12345), 3, nil}, - {[]byte{167, 15, 3, 231, 128, 0}, 7, 3, float64(9999.999), 4, nil}, - {[]byte{128, 0, 48, 57, 0, 128, 0}, 10, 2, float64(12345.00), 5, nil}, - {[]byte{128, 0, 48, 57, 0, 0, 128, 0}, 10, 3, float64(12345.000), 6, nil}, - {[]byte{128, 0, 0, 48, 57, 0, 137, 59}, 13, 2, float64(12345.00), 6, nil}, - {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, float64(9.99999999999999), 8, nil}, - {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 128, 0}, 20, 10, float64(12345.0000000000), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 57, 0, 0, 0, 128, 0}, 30, 5, float64(12345.00000), 15, nil}, - {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 48}, 30, 20, float64(12345.00000000000000000000), 14, nil}, - {[]byte{128, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0}, 30, 25, float64(12345.0000000000000000000000000), 15, nil}, - {[]byte{227, 99, 128, 48}, 4, 2, float64(99.99), 2, nil}, - {[]byte{128, 48, 57, 167, 15}, 5, 0, float64(12345), 3, nil}, - {[]byte{167, 15, 3, 231, 128, 0}, 7, 3, float64(9999.999), 4, nil}, - {[]byte{128, 0, 48, 57, 0, 128, 0}, 10, 2, float64(12345.00), 5, nil}, - {[]byte{128, 0, 48, 57, 0, 0, 128, 0}, 10, 3, float64(12345.000), 6, nil}, - {[]byte{128, 0, 0, 48, 57, 0, 137, 59}, 13, 2, float64(12345.00), 6, nil}, - {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, float64(9.99999999999999), 8, nil}, - {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 128, 0}, 20, 10, float64(12345.0000000000), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 57, 0, 0, 0, 128, 0}, 30, 5, float64(12345.00000), 15, nil}, - {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 48}, 30, 20, float64(12345.00000000000000000000), 14, nil}, - {[]byte{128, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0}, 30, 25, float64(12345.0000000000000000000000000), 15, nil}, - {[]byte{227, 99, 128, 0}, 4, 2, float64(99.99), 2, nil}, - {[]byte{128, 0, 123, 128, 123}, 5, 0, float64(123), 3, nil}, - {[]byte{128, 123, 1, 194, 128, 0}, 7, 3, float64(123.450), 4, nil}, - {[]byte{128, 0, 0, 123, 45, 128, 0}, 10, 2, float64(123.45), 5, nil}, - {[]byte{128, 0, 0, 123, 1, 194, 128, 0}, 10, 3, float64(123.450), 6, nil}, - {[]byte{128, 0, 0, 0, 123, 45, 137, 59}, 13, 2, float64(123.45), 6, nil}, - {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, float64(9.99999999999999), 8, nil}, - {[]byte{128, 0, 0, 0, 123, 26, 210, 116, 128, 0, 128, 0}, 20, 10, float64(123.4500000000), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123, 0, 175, 200, 128, 0}, 30, 5, float64(123.45000), 15, nil}, - {[]byte{128, 0, 0, 0, 123, 26, 210, 116, 128, 0, 0, 0, 0, 0, 128, 0}, 30, 20, float64(123.45000000000000000000), 14, nil}, - {[]byte{128, 0, 123, 26, 210, 116, 128, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0}, 30, 25, float64(123.4500000000000000000000000), 15, nil}, - {[]byte{28, 156, 127, 255}, 4, 2, float64(-99.99), 2, nil}, - {[]byte{127, 255, 132, 127, 132}, 5, 0, float64(-123), 3, nil}, - {[]byte{127, 132, 254, 61, 127, 255}, 7, 3, float64(-123.450), 4, nil}, - {[]byte{127, 255, 255, 132, 210, 127, 255}, 10, 2, float64(-123.45), 5, nil}, - {[]byte{127, 255, 255, 132, 254, 61, 127, 255}, 10, 3, float64(-123.450), 6, nil}, - {[]byte{127, 255, 255, 255, 132, 210, 118, 196}, 13, 2, float64(-123.45), 6, nil}, - {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, float64(-9.99999999999999), 8, nil}, - {[]byte{127, 255, 255, 255, 132, 229, 45, 139, 127, 255, 127, 255}, 20, 10, float64(-123.4500000000), 10, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 132, 255, 80, 55, 127, 255}, 30, 5, float64(-123.45000), 15, nil}, - {[]byte{127, 255, 255, 255, 132, 229, 45, 139, 127, 255, 255, 255, 255, 255, 127, 255}, 30, 20, float64(-123.45000000000000000000), 14, nil}, - {[]byte{127, 255, 132, 229, 45, 139, 127, 255, 255, 255, 255, 255, 255, 255, 255, 20, 0}, 30, 25, float64(-123.4500000000000000000000000), 15, nil}, - {[]byte{128, 0, 128, 0}, 4, 2, float64(0.00), 2, nil}, - {[]byte{128, 0, 0, 128, 0}, 5, 0, float64(0), 3, nil}, - {[]byte{128, 0, 0, 0, 128, 0}, 7, 3, float64(0.000), 4, nil}, - {[]byte{128, 0, 0, 0, 0, 128, 0}, 10, 2, float64(0.00), 5, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 10, 3, float64(0.000), 6, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 13, 2, float64(0.00), 6, nil}, - {[]byte{128, 0, 1, 226, 58, 0, 0, 99, 128, 0}, 15, 14, float64(0.00012345000099), 8, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 128, 0}, 20, 10, float64(0.0001234500), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 128, 0}, 30, 5, float64(0.00012), 15, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 128, 0}, 30, 20, float64(0.00012345000098765000), 14, nil}, - {[]byte{128, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 0, 0, 0, 15, 0}, 30, 25, float64(0.0001234500009876500000000), 15, nil}, - {[]byte{128, 0, 128, 0}, 4, 2, float64(0.00), 2, nil}, - {[]byte{128, 0, 0, 128, 0}, 5, 0, float64(0), 3, nil}, - {[]byte{128, 0, 0, 0, 128, 0}, 7, 3, float64(0.000), 4, nil}, - {[]byte{128, 0, 0, 0, 0, 128, 0}, 10, 2, float64(0.00), 5, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 10, 3, float64(0.000), 6, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 13, 2, float64(0.00), 6, nil}, - {[]byte{128, 0, 1, 226, 58, 0, 0, 99, 128, 0}, 15, 14, float64(0.00012345000099), 8, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 128, 0}, 20, 10, float64(0.0001234500), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 128, 0}, 30, 5, float64(0.00012), 15, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 128, 0}, 30, 20, float64(0.00012345000098765000), 14, nil}, - {[]byte{128, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 0, 0, 0, 22, 0}, 30, 25, float64(0.0001234500009876500000000), 15, nil}, - {[]byte{128, 12, 128, 0}, 4, 2, float64(0.12), 2, nil}, - {[]byte{128, 0, 0, 128, 0}, 5, 0, float64(0), 3, nil}, - {[]byte{128, 0, 0, 123, 128, 0}, 7, 3, float64(0.123), 4, nil}, - {[]byte{128, 0, 0, 0, 12, 128, 0}, 10, 2, float64(0.12), 5, nil}, - {[]byte{128, 0, 0, 0, 0, 123, 128, 0}, 10, 3, float64(0.123), 6, nil}, - {[]byte{128, 0, 0, 0, 0, 12, 128, 7}, 13, 2, float64(0.12), 6, nil}, - {[]byte{128, 7, 91, 178, 144, 1, 129, 205, 128, 0}, 15, 14, float64(0.12345000098765), 8, nil}, - {[]byte{128, 0, 0, 0, 0, 7, 91, 178, 145, 0, 128, 0}, 20, 10, float64(0.1234500010), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 57, 128, 0}, 30, 5, float64(0.12345), 15, nil}, - {[]byte{128, 0, 0, 0, 0, 7, 91, 178, 144, 58, 222, 87, 208, 0, 128, 0}, 30, 20, float64(0.12345000098765000000), 14, nil}, - {[]byte{128, 0, 0, 7, 91, 178, 144, 58, 222, 87, 208, 0, 0, 0, 0, 30, 0}, 30, 25, float64(0.1234500009876500000000000), 15, nil}, - {[]byte{128, 0, 128, 0}, 4, 2, float64(0.00), 2, nil}, - {[]byte{128, 0, 0, 128, 0}, 5, 0, float64(0), 3, nil}, - {[]byte{128, 0, 0, 0, 128, 0}, 7, 3, float64(0.000), 4, nil}, - {[]byte{128, 0, 0, 0, 0, 128, 0}, 10, 2, float64(0.00), 5, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 10, 3, float64(0.000), 6, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 127, 255}, 13, 2, float64(0.00), 6, nil}, - {[]byte{127, 255, 255, 255, 243, 255, 121, 59, 127, 255}, 15, 14, float64(-0.00000001234500), 8, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 243, 252, 128, 0}, 20, 10, float64(-0.0000000123), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255}, 30, 5, float64(0.00000), 15, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 243, 235, 111, 183, 93, 178, 127, 255}, 30, 20, float64(-0.00000001234500009877), 14, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 243, 235, 111, 183, 93, 255, 139, 69, 47, 30, 0}, 30, 25, float64(-0.0000000123450000987650000), 15, nil}, - {[]byte{227, 99, 129, 134}, 4, 2, float64(99.99), 2, nil}, - {[]byte{129, 134, 159, 167, 15}, 5, 0, float64(99999), 3, nil}, - {[]byte{167, 15, 3, 231, 133, 245}, 7, 3, float64(9999.999), 4, nil}, - {[]byte{133, 245, 224, 255, 99, 128, 152}, 10, 2, float64(99999999.99), 5, nil}, - {[]byte{128, 152, 150, 127, 3, 231, 227, 59}, 10, 3, float64(9999999.999), 6, nil}, - {[]byte{227, 59, 154, 201, 255, 99, 137, 59}, 13, 2, float64(99999999999.99), 6, nil}, - {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 137, 59}, 15, 14, float64(9.99999999999999), 8, nil}, - {[]byte{137, 59, 154, 201, 255, 59, 154, 201, 255, 9, 128, 0}, 20, 10, float64(9999999999.9999999999), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 4, 210, 29, 205, 139, 148, 0, 195, 80, 137, 59}, 30, 5, float64(1234500009876.50000), 15, nil}, - {[]byte{137, 59, 154, 201, 255, 59, 154, 201, 255, 59, 154, 201, 255, 99, 129, 134}, 30, 20, float64(9999999999.99999999999999999999), 14, nil}, - {[]byte{129, 134, 159, 59, 154, 201, 255, 59, 154, 201, 255, 0, 152, 150, 127, 30, 0}, 30, 25, float64(99999.9999999999999999999999999), 15, nil}, - {[]byte{227, 99, 129, 134}, 4, 2, float64(99.99), 2, nil}, - {[]byte{129, 134, 159, 167, 15}, 5, 0, float64(99999), 3, nil}, - {[]byte{167, 15, 3, 231, 133, 245}, 7, 3, float64(9999.999), 4, nil}, - {[]byte{133, 245, 224, 255, 99, 128, 152}, 10, 2, float64(99999999.99), 5, nil}, - {[]byte{128, 152, 150, 127, 3, 231, 128, 6}, 10, 3, float64(9999999.999), 6, nil}, - {[]byte{128, 6, 159, 107, 199, 11, 137, 59}, 13, 2, float64(111111111.11), 6, nil}, - {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 6}, 15, 14, float64(9.99999999999999), 8, nil}, - {[]byte{128, 6, 159, 107, 199, 6, 142, 119, 128, 0, 128, 0}, 20, 10, float64(111111111.1100000000), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 6, 159, 107, 199, 0, 42, 248, 128, 6}, 30, 5, float64(111111111.11000), 15, nil}, - {[]byte{128, 6, 159, 107, 199, 6, 142, 119, 128, 0, 0, 0, 0, 0, 129, 134}, 30, 20, float64(111111111.11000000000000000000), 14, nil}, - {[]byte{129, 134, 159, 59, 154, 201, 255, 59, 154, 201, 255, 0, 152, 150, 127, 10, 0}, 30, 25, float64(99999.9999999999999999999999999), 15, nil}, - {[]byte{128, 1, 128, 0}, 4, 2, float64(0.01), 2, nil}, - {[]byte{128, 0, 0, 128, 0}, 5, 0, float64(0), 3, nil}, - {[]byte{128, 0, 0, 10, 128, 0}, 7, 3, float64(0.010), 4, nil}, - {[]byte{128, 0, 0, 0, 1, 128, 0}, 10, 2, float64(0.01), 5, nil}, - {[]byte{128, 0, 0, 0, 0, 10, 128, 0}, 10, 3, float64(0.010), 6, nil}, - {[]byte{128, 0, 0, 0, 0, 1, 128, 0}, 13, 2, float64(0.01), 6, nil}, - {[]byte{128, 0, 152, 150, 128, 0, 0, 0, 128, 0}, 15, 14, float64(0.01000000000000), 8, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 152, 150, 128, 0, 128, 0}, 20, 10, float64(0.0100000000), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 232, 128, 0}, 30, 5, float64(0.01000), 15, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 152, 150, 128, 0, 0, 0, 0, 0, 128, 0}, 30, 20, float64(0.01000000000000000000), 14, nil}, - {[]byte{128, 0, 0, 0, 152, 150, 128, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0}, 30, 25, float64(0.0100000000000000000000000), 15, nil}, - {[]byte{227, 99, 128, 0}, 4, 2, float64(99.99), 2, nil}, - {[]byte{128, 0, 123, 128, 123}, 5, 0, float64(123), 3, nil}, - {[]byte{128, 123, 1, 144, 128, 0}, 7, 3, float64(123.400), 4, nil}, - {[]byte{128, 0, 0, 123, 40, 128, 0}, 10, 2, float64(123.40), 5, nil}, - {[]byte{128, 0, 0, 123, 1, 144, 128, 0}, 10, 3, float64(123.400), 6, nil}, - {[]byte{128, 0, 0, 0, 123, 40, 137, 59}, 13, 2, float64(123.40), 6, nil}, - {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, float64(9.99999999999999), 8, nil}, - {[]byte{128, 0, 0, 0, 123, 23, 215, 132, 0, 0, 128, 0}, 20, 10, float64(123.4000000000), 10, nil}, - {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123, 0, 156, 64, 128, 0}, 30, 5, float64(123.40000), 15, nil}, - {[]byte{128, 0, 0, 0, 123, 23, 215, 132, 0, 0, 0, 0, 0, 0, 128, 0}, 30, 20, float64(123.40000000000000000000), 14, nil}, - {[]byte{128, 0, 123, 23, 215, 132, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0}, 30, 25, float64(123.4000000000000000000000000), 15, nil}, - {[]byte{28, 156, 127, 253}, 4, 2, float64(-99.99), 2, nil}, - {[]byte{127, 253, 204, 125, 205}, 5, 0, float64(-563), 3, nil}, - {[]byte{125, 205, 253, 187, 127, 255}, 7, 3, float64(-562.580), 4, nil}, - {[]byte{127, 255, 253, 205, 197, 127, 255}, 10, 2, float64(-562.58), 5, nil}, - {[]byte{127, 255, 253, 205, 253, 187, 127, 255}, 10, 3, float64(-562.580), 6, nil}, - {[]byte{127, 255, 255, 253, 205, 197, 118, 196}, 13, 2, float64(-562.58), 6, nil}, - {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, float64(-9.99999999999999), 8, nil}, - {[]byte{127, 255, 255, 253, 205, 221, 109, 230, 255, 255, 127, 255}, 20, 10, float64(-562.5800000000), 10, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 253, 205, 255, 29, 111, 127, 255}, 30, 5, float64(-562.58000), 15, nil}, - {[]byte{127, 255, 255, 253, 205, 221, 109, 230, 255, 255, 255, 255, 255, 255, 127, 253}, 30, 20, float64(-562.58000000000000000000), 14, nil}, - {[]byte{127, 253, 205, 221, 109, 230, 255, 255, 255, 255, 255, 255, 255, 255, 255, 13, 0}, 30, 25, float64(-562.5800000000000000000000000), 15, nil}, - {[]byte{28, 156, 127, 241}, 4, 2, float64(-99.99), 2, nil}, - {[]byte{127, 241, 140, 113, 140}, 5, 0, float64(-3699), 3, nil}, - {[]byte{113, 140, 255, 245, 127, 255}, 7, 3, float64(-3699.010), 4, nil}, - {[]byte{127, 255, 241, 140, 254, 127, 255}, 10, 2, float64(-3699.01), 5, nil}, - {[]byte{127, 255, 241, 140, 255, 245, 127, 255}, 10, 3, float64(-3699.010), 6, nil}, - {[]byte{127, 255, 255, 241, 140, 254, 118, 196}, 13, 2, float64(-3699.01), 6, nil}, - {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, float64(-9.99999999999999), 8, nil}, - {[]byte{127, 255, 255, 241, 140, 255, 103, 105, 127, 255, 127, 255}, 20, 10, float64(-3699.0100000000), 10, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 241, 140, 255, 252, 23, 127, 255}, 30, 5, float64(-3699.01000), 15, nil}, - {[]byte{127, 255, 255, 241, 140, 255, 103, 105, 127, 255, 255, 255, 255, 255, 127, 241}, 30, 20, float64(-3699.01000000000000000000), 14, nil}, - {[]byte{127, 241, 140, 255, 103, 105, 127, 255, 255, 255, 255, 255, 255, 255, 255, 13, 0}, 30, 25, float64(-3699.0100000000000000000000000), 15, nil}, - {[]byte{28, 156, 127, 248}, 4, 2, float64(-99.99), 2, nil}, - {[]byte{127, 248, 99, 120, 99}, 5, 0, float64(-1948), 3, nil}, - {[]byte{120, 99, 255, 115, 127, 255}, 7, 3, float64(-1948.140), 4, nil}, - {[]byte{127, 255, 248, 99, 241, 127, 255}, 10, 2, float64(-1948.14), 5, nil}, - {[]byte{127, 255, 248, 99, 255, 115, 127, 255}, 10, 3, float64(-1948.140), 6, nil}, - {[]byte{127, 255, 255, 248, 99, 241, 118, 196}, 13, 2, float64(-1948.14), 6, nil}, - {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, float64(-9.99999999999999), 8, nil}, - {[]byte{127, 255, 255, 248, 99, 247, 167, 196, 255, 255, 127, 255}, 20, 10, float64(-1948.1400000000), 10, nil}, - {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 248, 99, 255, 201, 79, 127, 255}, 30, 5, float64(-1948.14000), 15, nil}, - {[]byte{127, 255, 255, 248, 99, 247, 167, 196, 255, 255, 255, 255, 255, 255, 127, 248}, 30, 20, float64(-1948.14000000000000000000), 14, nil}, - {[]byte{127, 248, 99, 247, 167, 196, 255, 255, 255, 255, 255, 255, 255, 255, 255, 13, 0}, 30, 25, float64(-1948.1400000000000000000000000), 15, nil}, + {[]byte{117, 200, 127, 255}, 4, 2, "-10.55", 2, nil}, + {[]byte{127, 255, 244, 127, 245}, 5, 0, "-11", 3, nil}, + {[]byte{127, 245, 253, 217, 127, 255}, 7, 3, "-10.550", 4, nil}, + {[]byte{127, 255, 255, 245, 200, 127, 255}, 10, 2, "-10.55", 5, nil}, + {[]byte{127, 255, 255, 245, 253, 217, 127, 255}, 10, 3, "-10.550", 6, nil}, + {[]byte{127, 255, 255, 255, 245, 200, 118, 196}, 13, 2, "-10.55", 6, nil}, + {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, "-9.99999999999999", 8, nil}, + {[]byte{127, 255, 255, 255, 245, 223, 55, 170, 127, 255, 127, 255}, 20, 10, "-10.5500000000", 10, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 245, 255, 41, 39, 127, 255}, 30, 5, "-10.55000", 15, nil}, + {[]byte{127, 255, 255, 255, 245, 223, 55, 170, 127, 255, 255, 255, 255, 255, 127, 255}, 30, 20, "-10.55000000000000000000", 14, nil}, + {[]byte{127, 255, 245, 223, 55, 170, 127, 255, 255, 255, 255, 255, 255, 255, 255, 4, 0}, 30, 25, "-10.5500000000000000000000000", 15, nil}, + {[]byte{128, 1, 128, 0}, 4, 2, "0.01", 2, nil}, + {[]byte{128, 0, 0, 128, 0}, 5, 0, "0", 3, nil}, + {[]byte{128, 0, 0, 12, 128, 0}, 7, 3, "0.012", 4, nil}, + {[]byte{128, 0, 0, 0, 1, 128, 0}, 10, 2, "0.01", 5, nil}, + {[]byte{128, 0, 0, 0, 0, 12, 128, 0}, 10, 3, "0.012", 6, nil}, + {[]byte{128, 0, 0, 0, 0, 1, 128, 0}, 13, 2, "0.01", 6, nil}, + {[]byte{128, 0, 188, 97, 78, 1, 96, 11, 128, 0}, 15, 14, "0.01234567890123", 8, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 188, 97, 78, 9, 128, 0}, 20, 10, "0.0123456789", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 211, 128, 0}, 30, 5, "0.01235", 15, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 188, 97, 78, 53, 183, 191, 135, 89, 128, 0}, 30, 20, "0.01234567890123456789", 14, nil}, + {[]byte{128, 0, 0, 0, 188, 97, 78, 53, 183, 191, 135, 0, 135, 253, 217, 30, 0}, 30, 25, "0.0123456789012345678912345", 15, nil}, + {[]byte{227, 99, 128, 48}, 4, 2, "99.99", 2, nil}, + {[]byte{128, 48, 57, 167, 15}, 5, 0, "12345", 3, nil}, + {[]byte{167, 15, 3, 231, 128, 0}, 7, 3, "9999.999", 4, nil}, + {[]byte{128, 0, 48, 57, 0, 128, 0}, 10, 2, "12345.00", 5, nil}, + {[]byte{128, 0, 48, 57, 0, 0, 128, 0}, 10, 3, "12345.000", 6, nil}, + {[]byte{128, 0, 0, 48, 57, 0, 137, 59}, 13, 2, "12345.00", 6, nil}, + {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, "9.99999999999999", 8, nil}, + {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 128, 0}, 20, 10, "12345.0000000000", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 57, 0, 0, 0, 128, 0}, 30, 5, "12345.00000", 15, nil}, + {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 48}, 30, 20, "12345.00000000000000000000", 14, nil}, + {[]byte{128, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0}, 30, 25, "12345.0000000000000000000000000", 15, nil}, + {[]byte{227, 99, 128, 48}, 4, 2, "99.99", 2, nil}, + {[]byte{128, 48, 57, 167, 15}, 5, 0, "12345", 3, nil}, + {[]byte{167, 15, 3, 231, 128, 0}, 7, 3, "9999.999", 4, nil}, + {[]byte{128, 0, 48, 57, 0, 128, 0}, 10, 2, "12345.00", 5, nil}, + {[]byte{128, 0, 48, 57, 0, 0, 128, 0}, 10, 3, "12345.000", 6, nil}, + {[]byte{128, 0, 0, 48, 57, 0, 137, 59}, 13, 2, "12345.00", 6, nil}, + {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, "9.99999999999999", 8, nil}, + {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 128, 0}, 20, 10, "12345.0000000000", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 57, 0, 0, 0, 128, 0}, 30, 5, "12345.00000", 15, nil}, + {[]byte{128, 0, 0, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 48}, 30, 20, "12345.00000000000000000000", 14, nil}, + {[]byte{128, 48, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0}, 30, 25, "12345.0000000000000000000000000", 15, nil}, + {[]byte{227, 99, 128, 0}, 4, 2, "99.99", 2, nil}, + {[]byte{128, 0, 123, 128, 123}, 5, 0, "123", 3, nil}, + {[]byte{128, 123, 1, 194, 128, 0}, 7, 3, "123.450", 4, nil}, + {[]byte{128, 0, 0, 123, 45, 128, 0}, 10, 2, "123.45", 5, nil}, + {[]byte{128, 0, 0, 123, 1, 194, 128, 0}, 10, 3, "123.450", 6, nil}, + {[]byte{128, 0, 0, 0, 123, 45, 137, 59}, 13, 2, "123.45", 6, nil}, + {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, "9.99999999999999", 8, nil}, + {[]byte{128, 0, 0, 0, 123, 26, 210, 116, 128, 0, 128, 0}, 20, 10, "123.4500000000", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123, 0, 175, 200, 128, 0}, 30, 5, "123.45000", 15, nil}, + {[]byte{128, 0, 0, 0, 123, 26, 210, 116, 128, 0, 0, 0, 0, 0, 128, 0}, 30, 20, "123.45000000000000000000", 14, nil}, + {[]byte{128, 0, 123, 26, 210, 116, 128, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0}, 30, 25, "123.4500000000000000000000000", 15, nil}, + {[]byte{28, 156, 127, 255}, 4, 2, "-99.99", 2, nil}, + {[]byte{127, 255, 132, 127, 132}, 5, 0, "-123", 3, nil}, + {[]byte{127, 132, 254, 61, 127, 255}, 7, 3, "-123.450", 4, nil}, + {[]byte{127, 255, 255, 132, 210, 127, 255}, 10, 2, "-123.45", 5, nil}, + {[]byte{127, 255, 255, 132, 254, 61, 127, 255}, 10, 3, "-123.450", 6, nil}, + {[]byte{127, 255, 255, 255, 132, 210, 118, 196}, 13, 2, "-123.45", 6, nil}, + {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, "-9.99999999999999", 8, nil}, + {[]byte{127, 255, 255, 255, 132, 229, 45, 139, 127, 255, 127, 255}, 20, 10, "-123.4500000000", 10, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 132, 255, 80, 55, 127, 255}, 30, 5, "-123.45000", 15, nil}, + {[]byte{127, 255, 255, 255, 132, 229, 45, 139, 127, 255, 255, 255, 255, 255, 127, 255}, 30, 20, "-123.45000000000000000000", 14, nil}, + {[]byte{127, 255, 132, 229, 45, 139, 127, 255, 255, 255, 255, 255, 255, 255, 255, 20, 0}, 30, 25, "-123.4500000000000000000000000", 15, nil}, + {[]byte{128, 0, 128, 0}, 4, 2, "0.00", 2, nil}, + {[]byte{128, 0, 0, 128, 0}, 5, 0, "0", 3, nil}, + {[]byte{128, 0, 0, 0, 128, 0}, 7, 3, "0.000", 4, nil}, + {[]byte{128, 0, 0, 0, 0, 128, 0}, 10, 2, "0.00", 5, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 10, 3, "0.000", 6, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 13, 2, "0.00", 6, nil}, + {[]byte{128, 0, 1, 226, 58, 0, 0, 99, 128, 0}, 15, 14, "0.00012345000099", 8, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 128, 0}, 20, 10, "0.0001234500", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 128, 0}, 30, 5, "0.00012", 15, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 128, 0}, 30, 20, "0.00012345000098765000", 14, nil}, + {[]byte{128, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 0, 0, 0, 15, 0}, 30, 25, "0.0001234500009876500000000", 15, nil}, + {[]byte{128, 0, 128, 0}, 4, 2, "0.00", 2, nil}, + {[]byte{128, 0, 0, 128, 0}, 5, 0, "0", 3, nil}, + {[]byte{128, 0, 0, 0, 128, 0}, 7, 3, "0.000", 4, nil}, + {[]byte{128, 0, 0, 0, 0, 128, 0}, 10, 2, "0.00", 5, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 10, 3, "0.000", 6, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 13, 2, "0.00", 6, nil}, + {[]byte{128, 0, 1, 226, 58, 0, 0, 99, 128, 0}, 15, 14, "0.00012345000099", 8, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 128, 0}, 20, 10, "0.0001234500", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 128, 0}, 30, 5, "0.00012", 15, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 128, 0}, 30, 20, "0.00012345000098765000", 14, nil}, + {[]byte{128, 0, 0, 0, 1, 226, 58, 0, 15, 18, 2, 0, 0, 0, 0, 22, 0}, 30, 25, "0.0001234500009876500000000", 15, nil}, + {[]byte{128, 12, 128, 0}, 4, 2, "0.12", 2, nil}, + {[]byte{128, 0, 0, 128, 0}, 5, 0, "0", 3, nil}, + {[]byte{128, 0, 0, 123, 128, 0}, 7, 3, "0.123", 4, nil}, + {[]byte{128, 0, 0, 0, 12, 128, 0}, 10, 2, "0.12", 5, nil}, + {[]byte{128, 0, 0, 0, 0, 123, 128, 0}, 10, 3, "0.123", 6, nil}, + {[]byte{128, 0, 0, 0, 0, 12, 128, 7}, 13, 2, "0.12", 6, nil}, + {[]byte{128, 7, 91, 178, 144, 1, 129, 205, 128, 0}, 15, 14, "0.12345000098765", 8, nil}, + {[]byte{128, 0, 0, 0, 0, 7, 91, 178, 145, 0, 128, 0}, 20, 10, "0.1234500010", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 48, 57, 128, 0}, 30, 5, "0.12345", 15, nil}, + {[]byte{128, 0, 0, 0, 0, 7, 91, 178, 144, 58, 222, 87, 208, 0, 128, 0}, 30, 20, "0.12345000098765000000", 14, nil}, + {[]byte{128, 0, 0, 7, 91, 178, 144, 58, 222, 87, 208, 0, 0, 0, 0, 30, 0}, 30, 25, "0.1234500009876500000000000", 15, nil}, + {[]byte{128, 0, 128, 0}, 4, 2, "0.00", 2, nil}, + {[]byte{128, 0, 0, 128, 0}, 5, 0, "0", 3, nil}, + {[]byte{128, 0, 0, 0, 128, 0}, 7, 3, "0.000", 4, nil}, + {[]byte{128, 0, 0, 0, 0, 128, 0}, 10, 2, "0.00", 5, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 128, 0}, 10, 3, "0.000", 6, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 127, 255}, 13, 2, "0.00", 6, nil}, + {[]byte{127, 255, 255, 255, 243, 255, 121, 59, 127, 255}, 15, 14, "-0.00000001234500", 8, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 243, 252, 128, 0}, 20, 10, "-0.0000000123", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 255}, 30, 5, "0.00000", 15, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 243, 235, 111, 183, 93, 178, 127, 255}, 30, 20, "-0.00000001234500009877", 14, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 243, 235, 111, 183, 93, 255, 139, 69, 47, 30, 0}, 30, 25, "-0.0000000123450000987650000", 15, nil}, + {[]byte{227, 99, 129, 134}, 4, 2, "99.99", 2, nil}, + {[]byte{129, 134, 159, 167, 15}, 5, 0, "99999", 3, nil}, + {[]byte{167, 15, 3, 231, 133, 245}, 7, 3, "9999.999", 4, nil}, + {[]byte{133, 245, 224, 255, 99, 128, 152}, 10, 2, "99999999.99", 5, nil}, + {[]byte{128, 152, 150, 127, 3, 231, 227, 59}, 10, 3, "9999999.999", 6, nil}, + {[]byte{227, 59, 154, 201, 255, 99, 137, 59}, 13, 2, "99999999999.99", 6, nil}, + {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 137, 59}, 15, 14, "9.99999999999999", 8, nil}, + {[]byte{137, 59, 154, 201, 255, 59, 154, 201, 255, 9, 128, 0}, 20, 10, "9999999999.9999999999", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 4, 210, 29, 205, 139, 148, 0, 195, 80, 137, 59}, 30, 5, "1234500009876.50000", 15, nil}, + {[]byte{137, 59, 154, 201, 255, 59, 154, 201, 255, 59, 154, 201, 255, 99, 129, 134}, 30, 20, "9999999999.99999999999999999999", 14, nil}, + {[]byte{129, 134, 159, 59, 154, 201, 255, 59, 154, 201, 255, 0, 152, 150, 127, 30, 0}, 30, 25, "99999.9999999999999999999999999", 15, nil}, + {[]byte{227, 99, 129, 134}, 4, 2, "99.99", 2, nil}, + {[]byte{129, 134, 159, 167, 15}, 5, 0, "99999", 3, nil}, + {[]byte{167, 15, 3, 231, 133, 245}, 7, 3, "9999.999", 4, nil}, + {[]byte{133, 245, 224, 255, 99, 128, 152}, 10, 2, "99999999.99", 5, nil}, + {[]byte{128, 152, 150, 127, 3, 231, 128, 6}, 10, 3, "9999999.999", 6, nil}, + {[]byte{128, 6, 159, 107, 199, 11, 137, 59}, 13, 2, "111111111.11", 6, nil}, + {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 6}, 15, 14, "9.99999999999999", 8, nil}, + {[]byte{128, 6, 159, 107, 199, 6, 142, 119, 128, 0, 128, 0}, 20, 10, "111111111.1100000000", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 6, 159, 107, 199, 0, 42, 248, 128, 6}, 30, 5, "111111111.11000", 15, nil}, + {[]byte{128, 6, 159, 107, 199, 6, 142, 119, 128, 0, 0, 0, 0, 0, 129, 134}, 30, 20, "111111111.11000000000000000000", 14, nil}, + {[]byte{129, 134, 159, 59, 154, 201, 255, 59, 154, 201, 255, 0, 152, 150, 127, 10, 0}, 30, 25, "99999.9999999999999999999999999", 15, nil}, + {[]byte{128, 1, 128, 0}, 4, 2, "0.01", 2, nil}, + {[]byte{128, 0, 0, 128, 0}, 5, 0, "0", 3, nil}, + {[]byte{128, 0, 0, 10, 128, 0}, 7, 3, "0.010", 4, nil}, + {[]byte{128, 0, 0, 0, 1, 128, 0}, 10, 2, "0.01", 5, nil}, + {[]byte{128, 0, 0, 0, 0, 10, 128, 0}, 10, 3, "0.010", 6, nil}, + {[]byte{128, 0, 0, 0, 0, 1, 128, 0}, 13, 2, "0.01", 6, nil}, + {[]byte{128, 0, 152, 150, 128, 0, 0, 0, 128, 0}, 15, 14, "0.01000000000000", 8, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 152, 150, 128, 0, 128, 0}, 20, 10, "0.0100000000", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 232, 128, 0}, 30, 5, "0.01000", 15, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 152, 150, 128, 0, 0, 0, 0, 0, 128, 0}, 30, 20, "0.01000000000000000000", 14, nil}, + {[]byte{128, 0, 0, 0, 152, 150, 128, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0}, 30, 25, "0.0100000000000000000000000", 15, nil}, + {[]byte{227, 99, 128, 0}, 4, 2, "99.99", 2, nil}, + {[]byte{128, 0, 123, 128, 123}, 5, 0, "123", 3, nil}, + {[]byte{128, 123, 1, 144, 128, 0}, 7, 3, "123.400", 4, nil}, + {[]byte{128, 0, 0, 123, 40, 128, 0}, 10, 2, "123.40", 5, nil}, + {[]byte{128, 0, 0, 123, 1, 144, 128, 0}, 10, 3, "123.400", 6, nil}, + {[]byte{128, 0, 0, 0, 123, 40, 137, 59}, 13, 2, "123.40", 6, nil}, + {[]byte{137, 59, 154, 201, 255, 1, 134, 159, 128, 0}, 15, 14, "9.99999999999999", 8, nil}, + {[]byte{128, 0, 0, 0, 123, 23, 215, 132, 0, 0, 128, 0}, 20, 10, "123.4000000000", 10, nil}, + {[]byte{128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123, 0, 156, 64, 128, 0}, 30, 5, "123.40000", 15, nil}, + {[]byte{128, 0, 0, 0, 123, 23, 215, 132, 0, 0, 0, 0, 0, 0, 128, 0}, 30, 20, "123.40000000000000000000", 14, nil}, + {[]byte{128, 0, 123, 23, 215, 132, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0}, 30, 25, "123.4000000000000000000000000", 15, nil}, + {[]byte{28, 156, 127, 253}, 4, 2, "-99.99", 2, nil}, + {[]byte{127, 253, 204, 125, 205}, 5, 0, "-563", 3, nil}, + {[]byte{125, 205, 253, 187, 127, 255}, 7, 3, "-562.580", 4, nil}, + {[]byte{127, 255, 253, 205, 197, 127, 255}, 10, 2, "-562.58", 5, nil}, + {[]byte{127, 255, 253, 205, 253, 187, 127, 255}, 10, 3, "-562.580", 6, nil}, + {[]byte{127, 255, 255, 253, 205, 197, 118, 196}, 13, 2, "-562.58", 6, nil}, + {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, "-9.99999999999999", 8, nil}, + {[]byte{127, 255, 255, 253, 205, 221, 109, 230, 255, 255, 127, 255}, 20, 10, "-562.5800000000", 10, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 253, 205, 255, 29, 111, 127, 255}, 30, 5, "-562.58000", 15, nil}, + {[]byte{127, 255, 255, 253, 205, 221, 109, 230, 255, 255, 255, 255, 255, 255, 127, 253}, 30, 20, "-562.58000000000000000000", 14, nil}, + {[]byte{127, 253, 205, 221, 109, 230, 255, 255, 255, 255, 255, 255, 255, 255, 255, 13, 0}, 30, 25, "-562.5800000000000000000000000", 15, nil}, + {[]byte{28, 156, 127, 241}, 4, 2, "-99.99", 2, nil}, + {[]byte{127, 241, 140, 113, 140}, 5, 0, "-3699", 3, nil}, + {[]byte{113, 140, 255, 245, 127, 255}, 7, 3, "-3699.010", 4, nil}, + {[]byte{127, 255, 241, 140, 254, 127, 255}, 10, 2, "-3699.01", 5, nil}, + {[]byte{127, 255, 241, 140, 255, 245, 127, 255}, 10, 3, "-3699.010", 6, nil}, + {[]byte{127, 255, 255, 241, 140, 254, 118, 196}, 13, 2, "-3699.01", 6, nil}, + {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, "-9.99999999999999", 8, nil}, + {[]byte{127, 255, 255, 241, 140, 255, 103, 105, 127, 255, 127, 255}, 20, 10, "-3699.0100000000", 10, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 241, 140, 255, 252, 23, 127, 255}, 30, 5, "-3699.01000", 15, nil}, + {[]byte{127, 255, 255, 241, 140, 255, 103, 105, 127, 255, 255, 255, 255, 255, 127, 241}, 30, 20, "-3699.01000000000000000000", 14, nil}, + {[]byte{127, 241, 140, 255, 103, 105, 127, 255, 255, 255, 255, 255, 255, 255, 255, 13, 0}, 30, 25, "-3699.0100000000000000000000000", 15, nil}, + {[]byte{28, 156, 127, 248}, 4, 2, "-99.99", 2, nil}, + {[]byte{127, 248, 99, 120, 99}, 5, 0, "-1948", 3, nil}, + {[]byte{120, 99, 255, 115, 127, 255}, 7, 3, "-1948.140", 4, nil}, + {[]byte{127, 255, 248, 99, 241, 127, 255}, 10, 2, "-1948.14", 5, nil}, + {[]byte{127, 255, 248, 99, 255, 115, 127, 255}, 10, 3, "-1948.140", 6, nil}, + {[]byte{127, 255, 255, 248, 99, 241, 118, 196}, 13, 2, "-1948.14", 6, nil}, + {[]byte{118, 196, 101, 54, 0, 254, 121, 96, 127, 255}, 15, 14, "-9.99999999999999", 8, nil}, + {[]byte{127, 255, 255, 248, 99, 247, 167, 196, 255, 255, 127, 255}, 20, 10, "-1948.1400000000", 10, nil}, + {[]byte{127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 248, 99, 255, 201, 79, 127, 255}, 30, 5, "-1948.14000", 15, nil}, + {[]byte{127, 255, 255, 248, 99, 247, 167, 196, 255, 255, 255, 255, 255, 255, 127, 248}, 30, 20, "-1948.14000000000000000000", 14, nil}, + {[]byte{127, 248, 99, 247, 167, 196, 255, 255, 255, 255, 255, 255, 255, 255, 255, 13, 0}, 30, 25, "-1948.1400000000000000000000000", 15, nil}, } for i, tc := range testcases { - value, pos, err := decodeDecimal(tc.Data, tc.Precision, tc.Decimals) - c.Assert(value, DecodeDecimalsEquals, pos, err, tc.Expected, tc.ExpectedPos, tc.ExpectedErr, i) + value, pos, err := decodeDecimal(tc.Data, tc.Precision, tc.Decimals, false) + expectedFloat, _ := strconv.ParseFloat(tc.Expected, 64) + c.Assert(value.(float64), DecodeDecimalsEquals, pos, err, expectedFloat, tc.ExpectedPos, tc.ExpectedErr, i) + + value, pos, err = decodeDecimal(tc.Data, tc.Precision, tc.Decimals, true) + expectedDecimal, _ := decimal.NewFromString(tc.Expected) + c.Assert(value.(decimal.Decimal), DecodeDecimalsEquals, pos, err, expectedDecimal, tc.ExpectedPos, tc.ExpectedErr, i) } } @@ -386,6 +393,25 @@ func (_ *testDecodeSuite) TestParseRowPanic(c *C) { c.Assert(rows.Rows[0][0], Equals, int32(16270)) } +type simpleDecimalEqualsChecker struct { + *CheckerInfo +} + +var SimpleDecimalEqualsChecker Checker = &simpleDecimalEqualsChecker{ + &CheckerInfo{Name: "Equals", Params: []string{"obtained", "expected"}}, +} + +func (checker *simpleDecimalEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) { + defer func() { + if v := recover(); v != nil { + result = false + error = fmt.Sprint(v) + } + }() + + return params[0].(decimal.Decimal).Equal(params[1].(decimal.Decimal)), "" +} + func (_ *testDecodeSuite) TestParseJson(c *C) { // Table format: // mysql> desc t10; @@ -403,7 +429,8 @@ func (_ *testDecodeSuite) TestParseJson(c *C) { // INSERT INTO `t10` (`c2`) VALUES (1); // INSERT INTO `t10` (`c1`, `c2`) VALUES ('{"key1": "value1", "key2": "value2"}', 1); - + // test json deserialization + // INSERT INTO `t10`(`c1`,`c2`) VALUES ('{"text":"Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo ligula eget dolor. Aenean massa. Cum sociis natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Donec quam felis, ultricies nec, pellentesque eu, pretium quis, sem. Nulla consequat massa quis enim. Donec pede justo, fringilla vel, aliquet nec, vulputate eget, arcu. In enim justo, rhoncus ut, imperdiet a, venenatis vitae, justo. Nullam dictum felis eu pede mollis pretium. Integer tincidunt. Cras dapibus. Vivamus elementum semper nisi. Aenean vulputate eleifend tellus. Aenean leo ligula, porttitor eu, consequat vitae, eleifend ac, enim. Aliquam lorem ante, dapibus in, viverra quis, feugiat a, tellus. Phasellus viverra nulla ut metus varius laoreet. Quisque rutrum. Aenean imperdiet. Etiam ultricies nisi vel augue. Curabitur ullamcorper ultricies nisi. Nam eget dui. Etiam rhoncus. Maecenas tempus, tellus eget condimentum rhoncus, sem quam semper libero, sit amet adipiscing sem neque sed ipsum. Nam quam nunc, blandit vel, luctus pulvinar, hendrerit id, lorem. Maecenas nec odio et ante tincidunt tempus. Donec vitae sapien ut libero venenatis faucibus. Nullam quis ante. Etiam sit amet orci eget eros faucibus tincidunt. Duis leo. Sed fringilla mauris sit amet nibh. Donec sodales sagittis magna. Sed consequat, leo eget bibendum sodales, augue velit cursus nunc, quis gravida magna mi a libero. Fusce vulputate eleifend sapien. Vestibulum purus quam, scelerisque ut, mollis sed, nonummy id, metus. Nullam accumsan lorem in dui. Cras ultricies mi eu turpis hendrerit fringilla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; In ac dui quis mi consectetuer lacinia. Nam pretium turpis et arcu. Duis arcu tortor, suscipit eget, imperdiet nec, imperdiet iaculis, ipsum. Sed aliquam ultrices mauris. Integer ante arcu, accumsan a, consectetuer eget, posuere ut, mauris. Praesent adipiscing. Phasellus ullamcorper ipsum rutrum nunc. Nunc nonummy metus. Vestibulum volutpat pretium libero. Cras id dui. Aenean ut eros et nisl sagittis vestibulum. Nullam nulla eros, ultricies sit amet, nonummy id, imperdiet feugiat, pede. Sed lectus. Donec mollis hendrerit risus. Phasellus nec sem in justo pellentesque facilisis. Etiam imperdiet imperdiet orci. Nunc nec neque. Phasellus leo dolor, tempus non, auctor et, hendrerit quis, nisi. Curabitur ligula sapien, tincidunt non, euismod vitae, posuere imperdiet, leo. Maecenas malesuada. Praesent congue erat at massa. Sed cursus turpis vitae tortor. Donec posuere vulputate arcu. Phasellus accumsan cursus velit. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Sed aliquam, nisi quis porttitor congue, elit erat euismod orci, ac"}',101); tableMapEventData := []byte("m\x00\x00\x00\x00\x00\x01\x00\x04test\x00\x03t10\x00\x02\xf5\xf6\x03\x04\n\x00\x03") tableMapEvent := new(TableMapEvent) @@ -428,4 +455,208 @@ func (_ *testDecodeSuite) TestParseJson(c *C) { c.Assert(err, IsNil) c.Assert(rows.Rows[0][1], Equals, float64(1)) } + + longTbls := [][]byte{ + []byte("m\x00\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfc\xd0\n\x00\x00\x00\x01\x00\xcf\n\v\x00\x04\x00\f\x0f\x00text\xbe\x15Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo ligula eget dolor. Aenean massa. Cum sociis natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Donec quam felis, ultricies nec, pellentesque eu, pretium quis, sem. Nulla consequat massa quis enim. Donec pede justo, fringilla vel, aliquet nec, vulputate eget, arcu. In enim justo, rhoncus ut, imperdiet a, venenatis vitae, justo. Nullam dictum felis eu pede mollis pretium. Integer tincidunt. Cras dapibus. Vivamus elementum semper nisi. Aenean vulputate eleifend tellus. Aenean leo ligula, porttitor eu, consequat vitae, eleifend ac, enim. Aliquam lorem ante, dapibus in, viverra quis, feugiat a, tellus. Phasellus viverra nulla ut metus varius laoreet. Quisque rutrum. Aenean imperdiet. Etiam ultricies nisi vel augue. Curabitur ullamcorper ultricies nisi. Nam eget dui. Etiam rhoncus. Maecenas tempus, tellus eget condimentum rhoncus, sem quam semper libero, sit amet adipiscing sem neque sed ipsum. Nam quam nunc, blandit vel, luctus pulvinar, hendrerit id, lorem. Maecenas nec odio et ante tincidunt tempus. Donec vitae sapien ut libero venenatis faucibus. Nullam quis ante. Etiam sit amet orci eget eros faucibus tincidunt. Duis leo. Sed fringilla mauris sit amet nibh. Donec sodales sagittis magna. Sed consequat, leo eget bibendum sodales, augue velit cursus nunc, quis gravida magna mi a libero. Fusce vulputate eleifend sapien. Vestibulum purus quam, scelerisque ut, mollis sed, nonummy id, metus. Nullam accumsan lorem in dui. Cras ultricies mi eu turpis hendrerit fringilla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; In ac dui quis mi consectetuer lacinia. Nam pretium turpis et arcu. Duis arcu tortor, suscipit eget, imperdiet nec, imperdiet iaculis, ipsum. Sed aliquam ultrices mauris. Integer ante arcu, accumsan a, consectetuer eget, posuere ut, mauris. Praesent adipiscing. Phasellus ullamcorper ipsum rutrum nunc. Nunc nonummy metus. Vestibulum volutpat pretium libero. Cras id dui. Aenean ut eros et nisl sagittis vestibulum. Nullam nulla eros, ultricies sit amet, nonummy id, imperdiet feugiat, pede. Sed lectus. Donec mollis hendrerit risus. Phasellus nec sem in justo pellentesque facilisis. Etiam imperdiet imperdiet orci. Nunc nec neque. Phasellus leo dolor, tempus non, auctor et, hendrerit quis, nisi. Curabitur ligula sapien, tincidunt non, euismod vitae, posuere imperdiet, leo. Maecenas malesuada. Praesent congue erat at massa. Sed cursus turpis vitae tortor. Donec posuere vulputate arcu. Phasellus accumsan cursus velit. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Sed aliquam, nisi quis porttitor congue, elit erat euismod orci, ac\x80\x00\x00\x00e"), + } + + for _, ltbl := range longTbls { + rows.Rows = nil + err = rows.Decode(ltbl) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][1], Equals, float64(101)) + } +} +func (_ *testDecodeSuite) TestParseJsonDecimal(c *C) { + // Table format: + // mysql> desc t10; + // +-------+---------------+------+-----+---------+-------+ + // | Field | Type | Null | Key | Default | Extra | + // +-------+---------------+------+-----+---------+-------+ + // | c1 | json | YES | | NULL | | + // | c2 | decimal(10,0) | YES | | NULL | | + // +-------+---------------+------+-----+---------+-------+ + + // CREATE TABLE `t10` ( + // `c1` json DEFAULT NULL, + // `c2` decimal(10,0) + // ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + + // INSERT INTO `t10` (`c2`) VALUES (1); + // INSERT INTO `t10` (`c1`, `c2`) VALUES ('{"key1": "value1", "key2": "value2"}', 1); + // test json deserialization + // INSERT INTO `t10`(`c1`,`c2`) VALUES ('{"text":"Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo ligula eget dolor. Aenean massa. Cum sociis natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Donec quam felis, ultricies nec, pellentesque eu, pretium quis, sem. Nulla consequat massa quis enim. Donec pede justo, fringilla vel, aliquet nec, vulputate eget, arcu. In enim justo, rhoncus ut, imperdiet a, venenatis vitae, justo. Nullam dictum felis eu pede mollis pretium. Integer tincidunt. Cras dapibus. Vivamus elementum semper nisi. Aenean vulputate eleifend tellus. Aenean leo ligula, porttitor eu, consequat vitae, eleifend ac, enim. Aliquam lorem ante, dapibus in, viverra quis, feugiat a, tellus. Phasellus viverra nulla ut metus varius laoreet. Quisque rutrum. Aenean imperdiet. Etiam ultricies nisi vel augue. Curabitur ullamcorper ultricies nisi. Nam eget dui. Etiam rhoncus. Maecenas tempus, tellus eget condimentum rhoncus, sem quam semper libero, sit amet adipiscing sem neque sed ipsum. Nam quam nunc, blandit vel, luctus pulvinar, hendrerit id, lorem. Maecenas nec odio et ante tincidunt tempus. Donec vitae sapien ut libero venenatis faucibus. Nullam quis ante. Etiam sit amet orci eget eros faucibus tincidunt. Duis leo. Sed fringilla mauris sit amet nibh. Donec sodales sagittis magna. Sed consequat, leo eget bibendum sodales, augue velit cursus nunc, quis gravida magna mi a libero. Fusce vulputate eleifend sapien. Vestibulum purus quam, scelerisque ut, mollis sed, nonummy id, metus. Nullam accumsan lorem in dui. Cras ultricies mi eu turpis hendrerit fringilla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; In ac dui quis mi consectetuer lacinia. Nam pretium turpis et arcu. Duis arcu tortor, suscipit eget, imperdiet nec, imperdiet iaculis, ipsum. Sed aliquam ultrices mauris. Integer ante arcu, accumsan a, consectetuer eget, posuere ut, mauris. Praesent adipiscing. Phasellus ullamcorper ipsum rutrum nunc. Nunc nonummy metus. Vestibulum volutpat pretium libero. Cras id dui. Aenean ut eros et nisl sagittis vestibulum. Nullam nulla eros, ultricies sit amet, nonummy id, imperdiet feugiat, pede. Sed lectus. Donec mollis hendrerit risus. Phasellus nec sem in justo pellentesque facilisis. Etiam imperdiet imperdiet orci. Nunc nec neque. Phasellus leo dolor, tempus non, auctor et, hendrerit quis, nisi. Curabitur ligula sapien, tincidunt non, euismod vitae, posuere imperdiet, leo. Maecenas malesuada. Praesent congue erat at massa. Sed cursus turpis vitae tortor. Donec posuere vulputate arcu. Phasellus accumsan cursus velit. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Sed aliquam, nisi quis porttitor congue, elit erat euismod orci, ac"}',101); + tableMapEventData := []byte("m\x00\x00\x00\x00\x00\x01\x00\x04test\x00\x03t10\x00\x02\xf5\xf6\x03\x04\n\x00\x03") + + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + err := tableMapEvent.Decode(tableMapEventData) + c.Assert(err, IsNil) + + rows := RowsEvent{useDecimal: true} + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + tbls := [][]byte{ + []byte("m\x00\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfd\x80\x00\x00\x00\x01"), + []byte("m\x00\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfc)\x00\x00\x00\x00\x02\x00(\x00\x12\x00\x04\x00\x16\x00\x04\x00\f\x1a\x00\f!\x00key1key2\x06value1\x06value2\x80\x00\x00\x00\x01"), + } + + for _, tbl := range tbls { + rows.Rows = nil + err = rows.Decode(tbl) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][1], SimpleDecimalEqualsChecker, decimal.NewFromFloat(1)) + } + + longTbls := [][]byte{ + []byte("m\x00\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfc\xd0\n\x00\x00\x00\x01\x00\xcf\n\v\x00\x04\x00\f\x0f\x00text\xbe\x15Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo ligula eget dolor. Aenean massa. Cum sociis natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Donec quam felis, ultricies nec, pellentesque eu, pretium quis, sem. Nulla consequat massa quis enim. Donec pede justo, fringilla vel, aliquet nec, vulputate eget, arcu. In enim justo, rhoncus ut, imperdiet a, venenatis vitae, justo. Nullam dictum felis eu pede mollis pretium. Integer tincidunt. Cras dapibus. Vivamus elementum semper nisi. Aenean vulputate eleifend tellus. Aenean leo ligula, porttitor eu, consequat vitae, eleifend ac, enim. Aliquam lorem ante, dapibus in, viverra quis, feugiat a, tellus. Phasellus viverra nulla ut metus varius laoreet. Quisque rutrum. Aenean imperdiet. Etiam ultricies nisi vel augue. Curabitur ullamcorper ultricies nisi. Nam eget dui. Etiam rhoncus. Maecenas tempus, tellus eget condimentum rhoncus, sem quam semper libero, sit amet adipiscing sem neque sed ipsum. Nam quam nunc, blandit vel, luctus pulvinar, hendrerit id, lorem. Maecenas nec odio et ante tincidunt tempus. Donec vitae sapien ut libero venenatis faucibus. Nullam quis ante. Etiam sit amet orci eget eros faucibus tincidunt. Duis leo. Sed fringilla mauris sit amet nibh. Donec sodales sagittis magna. Sed consequat, leo eget bibendum sodales, augue velit cursus nunc, quis gravida magna mi a libero. Fusce vulputate eleifend sapien. Vestibulum purus quam, scelerisque ut, mollis sed, nonummy id, metus. Nullam accumsan lorem in dui. Cras ultricies mi eu turpis hendrerit fringilla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; In ac dui quis mi consectetuer lacinia. Nam pretium turpis et arcu. Duis arcu tortor, suscipit eget, imperdiet nec, imperdiet iaculis, ipsum. Sed aliquam ultrices mauris. Integer ante arcu, accumsan a, consectetuer eget, posuere ut, mauris. Praesent adipiscing. Phasellus ullamcorper ipsum rutrum nunc. Nunc nonummy metus. Vestibulum volutpat pretium libero. Cras id dui. Aenean ut eros et nisl sagittis vestibulum. Nullam nulla eros, ultricies sit amet, nonummy id, imperdiet feugiat, pede. Sed lectus. Donec mollis hendrerit risus. Phasellus nec sem in justo pellentesque facilisis. Etiam imperdiet imperdiet orci. Nunc nec neque. Phasellus leo dolor, tempus non, auctor et, hendrerit quis, nisi. Curabitur ligula sapien, tincidunt non, euismod vitae, posuere imperdiet, leo. Maecenas malesuada. Praesent congue erat at massa. Sed cursus turpis vitae tortor. Donec posuere vulputate arcu. Phasellus accumsan cursus velit. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Sed aliquam, nisi quis porttitor congue, elit erat euismod orci, ac\x80\x00\x00\x00e"), + } + + for _, ltbl := range longTbls { + rows.Rows = nil + err = rows.Decode(ltbl) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][1], SimpleDecimalEqualsChecker, decimal.NewFromFloat(101)) + } +} + +func (_ *testDecodeSuite) TestEnum(c *C) { + // mysql> desc aenum; + // +-------+-------------------------------------------+------+-----+---------+-------+ + // | Field | Type | Null | Key | Default | Extra | + // +-------+-------------------------------------------+------+-----+---------+-------+ + // | id | int(11) | YES | | NULL | | + // | aset | enum('0','1','2','3','4','5','6','7','8') | YES | | NULL | | + // +-------+-------------------------------------------+------+-----+---------+-------+ + // 2 rows in set (0.00 sec) + // + // insert into aenum(id, aset) values(1, '0'); + tableMapEventData := []byte("\x42\x0f\x00\x00\x00\x00\x01\x00\x05\x74\x74\x65\x73\x74\x00\x05") + tableMapEventData = append(tableMapEventData, []byte("\x61\x65\x6e\x75\x6d\x00\x02\x03\xfe\x02\xf7\x01\x03")...) + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + err := tableMapEvent.Decode(tableMapEventData) + c.Assert(err, IsNil) + + rows := new(RowsEvent) + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + data := []byte("\x42\x0f\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfc\x01\x00\x00\x00\x01") + + rows.Rows = nil + err = rows.Decode(data) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][1], Equals, int64(1)) +} + +func (_ *testDecodeSuite) TestMultiBytesEnum(c *C) { + // CREATE TABLE numbers ( + // id int auto_increment, + // num ENUM( '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', '249', '250', '251', '252', '253', '254', '255','256','257' + + // ), + // primary key(id) + // ); + + // + // insert into numbers(num) values ('0'), ('256'); + tableMapEventData := []byte("\x84\x0f\x00\x00\x00\x00\x01\x00\x05\x74\x74\x65\x73\x74\x00\x07") + tableMapEventData = append(tableMapEventData, []byte("\x6e\x75\x6d\x62\x65\x72\x73\x00\x02\x03\xfe\x02\xf7\x02\x02")...) + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + err := tableMapEvent.Decode(tableMapEventData) + c.Assert(err, IsNil) + + rows := new(RowsEvent) + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + data := []byte("\x84\x0f\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfc\x01\x00\x00\x00\x01\x00\xfc\x02\x00\x00\x00\x01\x01") + + rows.Rows = nil + err = rows.Decode(data) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][1], Equals, int64(1)) + c.Assert(rows.Rows[1][1], Equals, int64(257)) +} + +func (_ *testDecodeSuite) TestSet(c *C) { + // mysql> desc aset; + // +--------+---------------------------------------------------------------------------------------+------+-----+---------+-------+ + // | Field | Type | Null | Key | Default | Extra | + // +--------+---------------------------------------------------------------------------------------+------+-----+---------+-------+ + // | id | int(11) | YES | | NULL | | + // | region | set('1','2','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18') | YES | | NULL | | + // +--------+---------------------------------------------------------------------------------------+------+-----+---------+-------+ + // 2 rows in set (0.00 sec) + // + // insert into aset(id, region) values(1, '1,3'); + + tableMapEventData := []byte("\xe7\x0e\x00\x00\x00\x00\x01\x00\x05\x74\x74\x65\x73\x74\x00\x04") + tableMapEventData = append(tableMapEventData, []byte("\x61\x73\x65\x74\x00\x02\x03\xfe\x02\xf8\x03\x03")...) + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + err := tableMapEvent.Decode(tableMapEventData) + c.Assert(err, IsNil) + + rows := new(RowsEvent) + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + data := []byte("\xe7\x0e\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfc\x01\x00\x00\x00\x05\x00\x00") + + rows.Rows = nil + err = rows.Decode(data) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][1], Equals, int64(5)) +} + +func (_ *testDecodeSuite) TestJsonNull(c *C) { + // Table: + // desc hj_order_preview + // +------------------+------------+------+-----+-------------------+----------------+ + // | Field | Type | Null | Key | Default | Extra | + // +------------------+------------+------+-----+-------------------+----------------+ + // | id | int(13) | NO | PRI | | auto_increment | + // | buyer_id | bigint(13) | NO | | | | + // | order_sn | bigint(13) | NO | | | | + // | order_detail | json | NO | | | | + // | is_del | tinyint(1) | NO | | 0 | | + // | add_time | int(13) | NO | | | | + // | last_update_time | timestamp | NO | | CURRENT_TIMESTAMP | | + // +------------------+------------+------+-----+-------------------+----------------+ + // insert into hj_order_preview + // (id, buyer_id, order_sn, is_del, add_time, last_update_time) + // values (1, 95891865464386, 13376222192996417, 0, 1479983995, 1479983995) + + tableMapEventData := []byte("r\x00\x00\x00\x00\x00\x01\x00\x04test\x00\x10hj_order_preview\x00\a\x03\b\b\xf5\x01\x03\x11\x02\x04\x00\x00") + + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + err := tableMapEvent.Decode(tableMapEventData) + c.Assert(err, IsNil) + + rows := new(RowsEvent) + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + data := + []byte("r\x00\x00\x00\x00\x00\x01\x00\x02\x00\a\xff\x80\x01\x00\x00\x00B\ue4d06W\x00\x00A\x10@l\x9a\x85/\x00\x00\x00\x00\x00\x00{\xc36X\x00\x00\x00\x00") + + rows.Rows = nil + err = rows.Decode(data) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][3], HasLen, 0) } diff --git a/vendor/github.com/siddontang/go-mysql/replication/time.go b/vendor/github.com/siddontang/go-mysql/replication/time.go new file mode 100644 index 0000000..bd27c4e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/replication/time.go @@ -0,0 +1,49 @@ +package replication + +import ( + "fmt" + "strings" + "time" +) + +var ( + fracTimeFormat []string +) + +// fracTime is a help structure wrapping Golang Time. +type fracTime struct { + time.Time + + // Dec must in [0, 6] + Dec int + + timestampStringLocation *time.Location +} + +func (t fracTime) String() string { + tt := t.Time + if t.timestampStringLocation != nil { + tt = tt.In(t.timestampStringLocation) + } + return tt.Format(fracTimeFormat[t.Dec]) +} + +func formatZeroTime(frac int, dec int) string { + if dec == 0 { + return "0000-00-00 00:00:00" + } + + s := fmt.Sprintf("0000-00-00 00:00:00.%06d", frac) + + // dec must < 6, if frac is 924000, but dec is 3, we must output 924 here. + return s[0 : len(s)-(6-dec)] +} + +func init() { + fracTimeFormat = make([]string, 7) + fracTimeFormat[0] = "2006-01-02 15:04:05" + + for i := 1; i <= 6; i++ { + fracTimeFormat[i] = fmt.Sprintf("2006-01-02 15:04:05.%s", strings.Repeat("0", i)) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/replication/time_test.go b/vendor/github.com/siddontang/go-mysql/replication/time_test.go new file mode 100644 index 0000000..3a06aaf --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/replication/time_test.go @@ -0,0 +1,70 @@ +package replication + +import ( + "time" + + . "github.com/pingcap/check" +) + +type testTimeSuite struct{} + +var _ = Suite(&testTimeSuite{}) + +func (s *testTimeSuite) TestTime(c *C) { + tbls := []struct { + year int + month int + day int + hour int + min int + sec int + microSec int + frac int + expected string + }{ + {2000, 1, 1, 1, 1, 1, 1, 0, "2000-01-01 01:01:01"}, + {2000, 1, 1, 1, 1, 1, 1, 1, "2000-01-01 01:01:01.0"}, + {2000, 1, 1, 1, 1, 1, 1, 6, "2000-01-01 01:01:01.000001"}, + } + + for _, t := range tbls { + t1 := fracTime{time.Date(t.year, time.Month(t.month), t.day, t.hour, t.min, t.sec, t.microSec*1000, time.UTC), t.frac, nil} + c.Assert(t1.String(), Equals, t.expected) + } + + zeroTbls := []struct { + frac int + dec int + expected string + }{ + {0, 1, "0000-00-00 00:00:00.0"}, + {1, 1, "0000-00-00 00:00:00.0"}, + {123, 3, "0000-00-00 00:00:00.000"}, + {123000, 3, "0000-00-00 00:00:00.123"}, + {123, 6, "0000-00-00 00:00:00.000123"}, + {123000, 6, "0000-00-00 00:00:00.123000"}, + } + + for _, t := range zeroTbls { + c.Assert(formatZeroTime(t.frac, t.dec), Equals, t.expected) + } +} + +func (s *testTimeSuite) TestTimeStringLocation(c *C) { + t := fracTime{ + time.Date(2018, time.Month(7), 30, 10, 0, 0, 0, time.FixedZone("EST", -5*3600)), + 0, + nil, + } + + c.Assert(t.String(), Equals, "2018-07-30 10:00:00") + + t = fracTime{ + time.Date(2018, time.Month(7), 30, 10, 0, 0, 0, time.FixedZone("EST", -5*3600)), + 0, + time.UTC, + } + c.Assert(t.String(), Equals, "2018-07-30 15:00:00") +} + +var _ = Suite(&testTimeSuite{}) diff --git a/vendor/github.com/siddontang/go-mysql/schema/schema.go b/vendor/github.com/siddontang/go-mysql/schema/schema.go index 86d2128..c98b9ac 100644 --- a/vendor/github.com/siddontang/go-mysql/schema/schema.go +++ b/vendor/github.com/siddontang/go-mysql/schema/schema.go @@ -5,6 +5,7 @@ package schema import ( + "database/sql" "fmt" "strings" @@ -12,6 +13,11 @@ import ( "github.com/siddontang/go-mysql/mysql" ) +var ErrTableNotExist = errors.New("table is not exist") +var ErrMissingTableMeta = errors.New("missing table meta") +var HAHealthCheckSchema = "mysql.ha_health_check" + +// Different column type const ( TYPE_NUMBER = iota + 1 // tinyint, smallint, mediumint, int, bigint, year TYPE_FLOAT // float, double @@ -24,12 +30,16 @@ const ( TYPE_TIME // time TYPE_BIT // bit TYPE_JSON // json + TYPE_DECIMAL // decimal ) type TableColumn struct { Name string Type int + Collation string + RawType string IsAuto bool + IsUnsigned bool EnumValues []string SetValues []string } @@ -47,22 +57,24 @@ type Table struct { Columns []TableColumn Indexes []*Index PKColumns []int + + UnsignedColumns []int } func (ta *Table) String() string { return fmt.Sprintf("%s.%s", ta.Schema, ta.Name) } -func (ta *Table) AddColumn(name string, columnType string, extra string) { +func (ta *Table) AddColumn(name string, columnType string, collation string, extra string) { index := len(ta.Columns) - ta.Columns = append(ta.Columns, TableColumn{Name: name}) + ta.Columns = append(ta.Columns, TableColumn{Name: name, Collation: collation}) + ta.Columns[index].RawType = columnType - if strings.Contains(columnType, "int") || strings.HasPrefix(columnType, "year") { - ta.Columns[index].Type = TYPE_NUMBER - } else if strings.HasPrefix(columnType, "float") || - strings.HasPrefix(columnType, "double") || - strings.HasPrefix(columnType, "decimal") { + if strings.HasPrefix(columnType, "float") || + strings.HasPrefix(columnType, "double") { ta.Columns[index].Type = TYPE_FLOAT + } else if strings.HasPrefix(columnType, "decimal") { + ta.Columns[index].Type = TYPE_DECIMAL } else if strings.HasPrefix(columnType, "enum") { ta.Columns[index].Type = TYPE_ENUM ta.Columns[index].EnumValues = strings.Split(strings.Replace( @@ -93,10 +105,17 @@ func (ta *Table) AddColumn(name string, columnType string, extra string) { ta.Columns[index].Type = TYPE_BIT } else if strings.HasPrefix(columnType, "json") { ta.Columns[index].Type = TYPE_JSON + } else if strings.Contains(columnType, "int") || strings.HasPrefix(columnType, "year") { + ta.Columns[index].Type = TYPE_NUMBER } else { ta.Columns[index].Type = TYPE_STRING } + if strings.Contains(columnType, "unsigned") || strings.Contains(columnType, "zerofill") { + ta.Columns[index].IsUnsigned = true + ta.UnsignedColumns = append(ta.UnsignedColumns, index) + } + if extra == "auto_increment" { ta.Columns[index].IsAuto = true } @@ -142,6 +161,35 @@ func (idx *Index) FindColumn(name string) int { return -1 } +func IsTableExist(conn mysql.Executer, schema string, name string) (bool, error) { + query := fmt.Sprintf("SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '%s' and TABLE_NAME = '%s' LIMIT 1", schema, name) + r, err := conn.Execute(query) + if err != nil { + return false, errors.Trace(err) + } + + return r.RowNumber() == 1, nil +} + +func NewTableFromSqlDB(conn *sql.DB, schema string, name string) (*Table, error) { + ta := &Table{ + Schema: schema, + Name: name, + Columns: make([]TableColumn, 0, 16), + Indexes: make([]*Index, 0, 8), + } + + if err := ta.fetchColumnsViaSqlDB(conn); err != nil { + return nil, errors.Trace(err) + } + + if err := ta.fetchIndexesViaSqlDB(conn); err != nil { + return nil, errors.Trace(err) + } + + return ta, nil +} + func NewTable(conn mysql.Executer, schema string, name string) (*Table, error) { ta := &Table{ Schema: schema, @@ -151,18 +199,18 @@ func NewTable(conn mysql.Executer, schema string, name string) (*Table, error) { } if err := ta.fetchColumns(conn); err != nil { - return nil, err + return nil, errors.Trace(err) } if err := ta.fetchIndexes(conn); err != nil { - return nil, err + return nil, errors.Trace(err) } return ta, nil } func (ta *Table) fetchColumns(conn mysql.Executer) error { - r, err := conn.Execute(fmt.Sprintf("describe `%s`.`%s`", ta.Schema, ta.Name)) + r, err := conn.Execute(fmt.Sprintf("show full columns from `%s`.`%s`", ta.Schema, ta.Name)) if err != nil { return errors.Trace(err) } @@ -170,14 +218,39 @@ func (ta *Table) fetchColumns(conn mysql.Executer) error { for i := 0; i < r.RowNumber(); i++ { name, _ := r.GetString(i, 0) colType, _ := r.GetString(i, 1) - extra, _ := r.GetString(i, 5) + collation, _ := r.GetString(i, 2) + extra, _ := r.GetString(i, 6) - ta.AddColumn(name, colType, extra) + ta.AddColumn(name, colType, collation, extra) } return nil } +func (ta *Table) fetchColumnsViaSqlDB(conn *sql.DB) error { + r, err := conn.Query(fmt.Sprintf("show full columns from `%s`.`%s`", ta.Schema, ta.Name)) + if err != nil { + return errors.Trace(err) + } + + defer r.Close() + + var unusedVal interface{} + unused := &unusedVal + + for r.Next() { + var name, colType, extra string + var collation sql.NullString + err := r.Scan(&name, &colType, &collation, &unused, &unused, &unused, &extra, &unused, &unused) + if err != nil { + return errors.Trace(err) + } + ta.AddColumn(name, colType, collation.String, extra) + } + + return r.Err() +} + func (ta *Table) fetchIndexes(conn mysql.Executer) error { r, err := conn.Execute(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name)) if err != nil { @@ -197,6 +270,87 @@ func (ta *Table) fetchIndexes(conn mysql.Executer) error { currentIndex.AddColumn(colName, cardinality) } + return ta.fetchPrimaryKeyColumns() + +} + +func (ta *Table) fetchIndexesViaSqlDB(conn *sql.DB) error { + r, err := conn.Query(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name)) + if err != nil { + return errors.Trace(err) + } + + defer r.Close() + + var currentIndex *Index + currentName := "" + + var unusedVal interface{} + unused := &unusedVal + + for r.Next() { + var indexName, colName string + var cardinality interface{} + + err := r.Scan( + &unused, + &unused, + &indexName, + &unused, + &colName, + &unused, + &cardinality, + &unused, + &unused, + &unused, + &unused, + &unused, + &unused, + ) + if err != nil { + return errors.Trace(err) + } + + if currentName != indexName { + currentIndex = ta.AddIndex(indexName) + currentName = indexName + } + + c := toUint64(cardinality) + currentIndex.AddColumn(colName, c) + } + + return ta.fetchPrimaryKeyColumns() +} + +func toUint64(i interface{}) uint64 { + switch i := i.(type) { + case int: + return uint64(i) + case int8: + return uint64(i) + case int16: + return uint64(i) + case int32: + return uint64(i) + case int64: + return uint64(i) + case uint: + return uint64(i) + case uint8: + return uint64(i) + case uint16: + return uint64(i) + case uint32: + return uint64(i) + case uint64: + return uint64(i) + } + + return 0 +} + +func (ta *Table) fetchPrimaryKeyColumns() error { if len(ta.Indexes) == 0 { return nil } @@ -213,3 +367,32 @@ func (ta *Table) fetchIndexes(conn mysql.Executer) error { return nil } + +// Get primary keys in one row for a table, a table may use multi fields as the PK +func (ta *Table) GetPKValues(row []interface{}) ([]interface{}, error) { + indexes := ta.PKColumns + if len(indexes) == 0 { + return nil, errors.Errorf("table %s has no PK", ta) + } else if len(ta.Columns) != len(row) { + return nil, errors.Errorf("table %s has %d columns, but row data %v len is %d", ta, + len(ta.Columns), row, len(row)) + } + + values := make([]interface{}, 0, len(indexes)) + + for _, index := range indexes { + values = append(values, row[index]) + } + + return values, nil +} + +// Get term column's value +func (ta *Table) GetColumnValue(column string, row []interface{}) (interface{}, error) { + index := ta.FindColumn(column) + if index == -1 { + return nil, errors.Errorf("table %s has no column name %s", ta, column) + } + + return row[index], nil +} diff --git a/vendor/github.com/siddontang/go-mysql/schema/schema_test.go b/vendor/github.com/siddontang/go-mysql/schema/schema_test.go index 327c622..c5bafe1 100644 --- a/vendor/github.com/siddontang/go-mysql/schema/schema_test.go +++ b/vendor/github.com/siddontang/go-mysql/schema/schema_test.go @@ -1,12 +1,14 @@ package schema import ( + "database/sql" "flag" "fmt" "testing" . "github.com/pingcap/check" "github.com/siddontang/go-mysql/client" + _ "github.com/siddontang/go-mysql/driver" ) // use docker mysql for test @@ -17,7 +19,8 @@ func Test(t *testing.T) { } type schemaTestSuite struct { - conn *client.Conn + conn *client.Conn + sqlDB *sql.DB } var _ = Suite(&schemaTestSuite{}) @@ -26,12 +29,19 @@ func (s *schemaTestSuite) SetUpSuite(c *C) { var err error s.conn, err = client.Connect(fmt.Sprintf("%s:%d", *host, 3306), "root", "", "test") c.Assert(err, IsNil) + + s.sqlDB, err = sql.Open("mysql", fmt.Sprintf("root:@%s:3306", *host)) + c.Assert(err, IsNil) } func (s *schemaTestSuite) TearDownSuite(c *C) { if s.conn != nil { s.conn.Close() } + + if s.sqlDB != nil { + s.sqlDB.Close() + } } func (s *schemaTestSuite) TestSchema(c *C) { @@ -44,10 +54,14 @@ func (s *schemaTestSuite) TestSchema(c *C) { id1 INT, id2 INT, name VARCHAR(256), - e ENUM("a", "b", "c"), + status ENUM('appointing','serving','abnormal','stop','noaftermarket','finish','financial_audit'), se SET('a', 'b', 'c'), f FLOAT, d DECIMAL(2, 1), + uint INT UNSIGNED, + zfint INT ZEROFILL, + name_ucs VARCHAR(256) CHARACTER SET ucs2, + name_utf8 VARCHAR(256) CHARACTER SET utf8, PRIMARY KEY(id2, id), UNIQUE (id1), INDEX name_idx (name) @@ -60,15 +74,25 @@ func (s *schemaTestSuite) TestSchema(c *C) { ta, err := NewTable(s.conn, "test", "schema_test") c.Assert(err, IsNil) - c.Assert(ta.Columns, HasLen, 8) + c.Assert(ta.Columns, HasLen, 12) c.Assert(ta.Indexes, HasLen, 3) c.Assert(ta.PKColumns, DeepEquals, []int{2, 0}) c.Assert(ta.Indexes[0].Columns, HasLen, 2) c.Assert(ta.Indexes[0].Name, Equals, "PRIMARY") c.Assert(ta.Indexes[2].Name, Equals, "name_idx") - c.Assert(ta.Columns[4].EnumValues, DeepEquals, []string{"a", "b", "c"}) + c.Assert(ta.Columns[4].EnumValues, DeepEquals, []string{"appointing", "serving", "abnormal", "stop", "noaftermarket", "finish", "financial_audit"}) c.Assert(ta.Columns[5].SetValues, DeepEquals, []string{"a", "b", "c"}) - c.Assert(ta.Columns[7].Type, Equals, TYPE_FLOAT) + c.Assert(ta.Columns[7].Type, Equals, TYPE_DECIMAL) + c.Assert(ta.Columns[0].IsUnsigned, IsFalse) + c.Assert(ta.Columns[8].IsUnsigned, IsTrue) + c.Assert(ta.Columns[9].IsUnsigned, IsTrue) + c.Assert(ta.Columns[10].Collation, Matches, "^ucs2.*") + c.Assert(ta.Columns[11].Collation, Matches, "^utf8.*") + + taSqlDb, err := NewTableFromSqlDB(s.sqlDB, "test", "schema_test") + c.Assert(err, IsNil) + + c.Assert(taSqlDb, DeepEquals, ta) } func (s *schemaTestSuite) TestQuoteSchema(c *C) { diff --git a/vendor/github.com/siddontang/go-mysql/server/auth.go b/vendor/github.com/siddontang/go-mysql/server/auth.go index b66ea4e..0eb54a6 100644 --- a/vendor/github.com/siddontang/go-mysql/server/auth.go +++ b/vendor/github.com/siddontang/go-mysql/server/auth.go @@ -2,118 +2,173 @@ package server import ( "bytes" - "encoding/binary" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/tls" + "fmt" + "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" ) -func (c *Conn) writeInitialHandshake() error { - capability := CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | - CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | - CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION +var ErrAccessDenied = errors.New("access denied") - data := make([]byte, 4, 128) +func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error { + switch authPluginName { + case AUTH_NATIVE_PASSWORD: + if err := c.acquirePassword(); err != nil { + return err + } + return c.compareNativePasswordAuthData(clientAuthData, c.password) - //min version 10 - data = append(data, 10) + case AUTH_CACHING_SHA2_PASSWORD: + if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { + return err + } + if c.cachingSha2FullAuth { + return c.handleAuthSwitchResponse() + } + return nil - //server version[00] - data = append(data, ServerVersion...) - data = append(data, 0) + case AUTH_SHA256_PASSWORD: + if err := c.acquirePassword(); err != nil { + return err + } + cont, err := c.handlePublicKeyRetrieval(clientAuthData) + if err != nil { + return err + } + if !cont { + return nil + } + return c.compareSha256PasswordAuthData(clientAuthData, c.password) - //connection id - data = append(data, byte(c.connectionID), byte(c.connectionID>>8), byte(c.connectionID>>16), byte(c.connectionID>>24)) - - //auth-plugin-data-part-1 - data = append(data, c.salt[0:8]...) - - //filter [00] - data = append(data, 0) - - //capability flag lower 2 bytes, using default capability here - data = append(data, byte(capability), byte(capability>>8)) - - //charset, utf-8 default - data = append(data, uint8(DEFAULT_COLLATION_ID)) - - //status - data = append(data, byte(c.status), byte(c.status>>8)) - - //below 13 byte may not be used - //capability flag upper 2 bytes, using default capability here - data = append(data, byte(capability>>16), byte(capability>>24)) - - //filter [0x15], for wireshark dump, value is 0x15 - data = append(data, 0x15) - - //reserved 10 [00] - data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - - //auth-plugin-data-part-2 - data = append(data, c.salt[8:]...) - - //filter [00] - data = append(data, 0) - - return c.WritePacket(data) + default: + return errors.Errorf("unknown authentication plugin name '%s'", authPluginName) + } } -func (c *Conn) readHandshakeResponse(password string) error { - data, err := c.ReadPacket() - +func (c *Conn) acquirePassword() error { + password, found, err := c.credentialProvider.GetCredential(c.user) if err != nil { return err } - - pos := 0 - - //capability - c.capability = binary.LittleEndian.Uint32(data[:4]) - pos += 4 - - //skip max packet size - pos += 4 - - //charset, skip, if you want to use another charset, use set names - //c.collation = CollationId(data[pos]) - pos++ - - //skip reserved 23[00] - pos += 23 - - //user name - user := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) - pos += len(user) + 1 - - if c.user != user { - return NewDefaultError(ER_NO_SUCH_USER, user, c.RemoteAddr().String()) + if !found { + return NewDefaultError(ER_NO_SUCH_USER, c.user, c.RemoteAddr().String()) } - - //auth length and auth - authLen := int(data[pos]) - pos++ - auth := data[pos : pos+authLen] - - checkAuth := CalcPassword(c.salt, []byte(password)) - - if !bytes.Equal(auth, checkAuth) { - return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.RemoteAddr().String(), c.user, "Yes") - } - - pos += authLen - - if c.capability|CLIENT_CONNECT_WITH_DB > 0 { - if len(data[pos:]) == 0 { - return nil - } - - db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) - pos += len(db) + 1 - - if err = c.h.UseDB(db); err != nil { - return err - } - } - + c.password = password + return nil +} + +func scrambleValidation(cached, nonce, scramble []byte) bool { + // SHA256(SHA256(SHA256(STORED_PASSWORD)), NONCE) + crypt := sha256.New() + crypt.Write(cached) + crypt.Write(nonce) + message2 := crypt.Sum(nil) + // SHA256(PASSWORD) + if len(message2) != len(scramble) { + return false + } + for i := range message2 { + message2[i] ^= scramble[i] + } + // SHA256(SHA256(PASSWORD) + crypt.Reset() + crypt.Write(message2) + m := crypt.Sum(nil) + return bytes.Equal(m, cached) +} + +func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, password string) error { + if bytes.Equal(CalcPassword(c.salt, []byte(c.password)), clientAuthData) { + return nil + } + return ErrAccessDenied +} + +func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password string) error { + // Empty passwords are not hashed, but sent as empty string + if len(clientAuthData) == 0 { + if password == "" { + return nil + } + return ErrAccessDenied + } + if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok { + if !tlsConn.ConnectionState().HandshakeComplete { + return errors.New("incomplete TSL handshake") + } + // connection is SSL/TLS, client should send plain password + // deal with the trailing \NUL added for plain text password received + if l := len(clientAuthData); l != 0 && clientAuthData[l-1] == 0x00 { + clientAuthData = clientAuthData[:l-1] + } + if bytes.Equal(clientAuthData, []byte(password)) { + return nil + } + return ErrAccessDenied + } else { + // client should send encrypted password + // decrypt + dbytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, (c.serverConf.tlsConfig.Certificates[0].PrivateKey).(*rsa.PrivateKey), clientAuthData, nil) + if err != nil { + return err + } + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(c.salt) + plain[i] ^= c.salt[j] + } + if bytes.Equal(plain, dbytes) { + return nil + } + return ErrAccessDenied + } +} + +func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { + // Empty passwords are not hashed, but sent as empty string + if len(clientAuthData) == 0 { + if err := c.acquirePassword(); err != nil { + return err + } + if c.password == "" { + return nil + } + return ErrAccessDenied + } + // the caching of 'caching_sha2_password' in MySQL, see: https://dev.mysql.com/worklog/task/?id=9591 + if _, ok := c.credentialProvider.(*InMemoryProvider); ok { + // since we have already kept the password in memory and calculate the scramble is not that high of cost, we eliminate + // the caching part. So our server will never ask the client to do a full authentication via RSA key exchange and it appears + // like the auth will always hit the cache. + if err := c.acquirePassword(); err != nil { + return err + } + if bytes.Equal(CalcCachingSha2Password(c.salt, c.password), clientAuthData) { + // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 + return c.writeAuthMoreDataFastAuth() + } + return ErrAccessDenied + } + // other type of credential provider, we use the cache + cached, ok := c.serverConf.cacheShaPassword.Load(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr())) + if ok { + // Scramble validation + if scrambleValidation(cached.([]byte), c.salt, clientAuthData) { + // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 + return c.writeAuthMoreDataFastAuth() + } + return ErrAccessDenied + } + // cache miss, do full auth + if err := c.writeAuthMoreDataFullAuth(); err != nil { + return err + } + c.cachingSha2FullAuth = true return nil } diff --git a/vendor/github.com/siddontang/go-mysql/server/auth_switch_response.go b/vendor/github.com/siddontang/go-mysql/server/auth_switch_response.go new file mode 100644 index 0000000..038acff --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/auth_switch_response.go @@ -0,0 +1,133 @@ +package server + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/tls" + "fmt" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" +) + +func (c *Conn) handleAuthSwitchResponse() error { + authData, err := c.readAuthSwitchRequestResponse() + if err != nil { + return err + } + + switch c.authPluginName { + case AUTH_NATIVE_PASSWORD: + if err := c.acquirePassword(); err != nil { + return err + } + if !bytes.Equal(CalcPassword(c.salt, []byte(c.password)), authData) { + return ErrAccessDenied + } + return nil + + case AUTH_CACHING_SHA2_PASSWORD: + if !c.cachingSha2FullAuth { + // Switched auth method but no MoreData packet send yet + if err := c.compareCacheSha2PasswordAuthData(authData); err != nil { + return err + } else { + if c.cachingSha2FullAuth { + return c.handleAuthSwitchResponse() + } + return nil + } + } + // AuthMoreData packet already sent, do full auth + if err := c.handleCachingSha2PasswordFullAuth(authData); err != nil { + return err + } + c.writeCachingSha2Cache() + return nil + + case AUTH_SHA256_PASSWORD: + cont, err := c.handlePublicKeyRetrieval(authData) + if err != nil { + return err + } + if !cont { + return nil + } + if err := c.acquirePassword(); err != nil { + return err + } + return c.compareSha256PasswordAuthData(authData, c.password) + + default: + return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName) + } +} + +func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { + if err := c.acquirePassword(); err != nil { + return err + } + if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok { + if !tlsConn.ConnectionState().HandshakeComplete { + return errors.New("incomplete TSL handshake") + } + // connection is SSL/TLS, client should send plain password + // deal with the trailing \NUL added for plain text password received + if l := len(authData); l != 0 && authData[l-1] == 0x00 { + authData = authData[:l-1] + } + if bytes.Equal(authData, []byte(c.password)) { + return nil + } + return ErrAccessDenied + } else { + // client either request for the public key or send the encrypted password + if len(authData) == 1 && authData[0] == 0x02 { + // send the public key + if err := c.writeAuthMoreDataPubkey(); err != nil { + return err + } + // read the encrypted password + var err error + if authData, err = c.readAuthSwitchRequestResponse(); err != nil { + return err + } + } + // the encrypted password + // decrypt + dbytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, (c.serverConf.tlsConfig.Certificates[0].PrivateKey).(*rsa.PrivateKey), authData, nil) + if err != nil { + return err + } + plain := make([]byte, len(c.password)+1) + copy(plain, c.password) + for i := range plain { + j := i % len(c.salt) + plain[i] ^= c.salt[j] + } + if bytes.Equal(plain, dbytes) { + return nil + } + return ErrAccessDenied + } +} + +func (c *Conn) writeCachingSha2Cache() { + // write cache + if c.password == "" { + return + } + // SHA256(PASSWORD) + crypt := sha256.New() + crypt.Write([]byte(c.password)) + m1 := crypt.Sum(nil) + // SHA256(SHA256(PASSWORD)) + crypt.Reset() + crypt.Write(m1) + m2 := crypt.Sum(nil) + // caching_sha2_password will maintain an in-memory hash of `user`@`host` => SHA256(SHA256(PASSWORD)) + c.serverConf.cacheShaPassword.Store(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr()), m2) +} diff --git a/vendor/github.com/siddontang/go-mysql/server/caching_sha2_cache_test.go b/vendor/github.com/siddontang/go-mysql/server/caching_sha2_cache_test.go new file mode 100644 index 0000000..a8139eb --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/caching_sha2_cache_test.go @@ -0,0 +1,233 @@ +package server + +import ( + "database/sql" + "fmt" + "net" + "strings" + "sync" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/juju/errors" + . "github.com/pingcap/check" + "github.com/siddontang/go-log/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/test_util/test_keys" +) + +var delay = 50 + +// test caching for 'caching_sha2_password' +// NOTE the idea here is to plugin a throttled credential provider so that the first connection (cache miss) will take longer time +// than the second connection (cache hit). Remember to set the password for MySQL user otherwise it won't cache empty password. +func TestCachingSha2Cache(t *testing.T) { + log.SetLevel(log.LevelDebug) + + remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50} + remoteProvider.AddUser(*testUser, *testPassword) + cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) + + // no TLS + Suite(&cacheTestSuite{ + server: cacheServer, + credProvider: remoteProvider, + tlsPara: "false", + }) + + TestingT(t) +} + +func TestCachingSha2CacheTLS(t *testing.T) { + log.SetLevel(log.LevelDebug) + + remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50} + remoteProvider.AddUser(*testUser, *testPassword) + cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) + + // TLS + Suite(&cacheTestSuite{ + server: cacheServer, + credProvider: remoteProvider, + tlsPara: "skip-verify", + }) + + TestingT(t) +} + +type RemoteThrottleProvider struct { + *InMemoryProvider + delay int // in milliseconds +} + +func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) { + time.Sleep(time.Millisecond * time.Duration(m.delay)) + return m.InMemoryProvider.GetCredential(username) +} + +type cacheTestSuite struct { + server *Server + credProvider CredentialProvider + tlsPara string + + db *sql.DB + + l net.Listener +} + +func (s *cacheTestSuite) SetUpSuite(c *C) { + var err error + + s.l, err = net.Listen("tcp", *testAddr) + c.Assert(err, IsNil) + + go s.onAccept(c) + + time.Sleep(30 * time.Millisecond) +} + +func (s *cacheTestSuite) TearDownSuite(c *C) { + if s.l != nil { + s.l.Close() + } +} + +func (s *cacheTestSuite) onAccept(c *C) { + for { + conn, err := s.l.Accept() + if err != nil { + return + } + + go s.onConn(conn, c) + } +} + +func (s *cacheTestSuite) onConn(conn net.Conn, c *C) { + //co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) + co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testCacheHandler{s}) + c.Assert(err, IsNil) + for { + err = co.HandleCommand() + if err != nil { + return + } + } +} + +func (s *cacheTestSuite) runSelect(c *C) { + var a int64 + var b string + + err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=1").Scan(&a, &b) + c.Assert(err, IsNil) + c.Assert(a, Equals, int64(1)) + c.Assert(b, Equals, "hello world") +} + +func (s *cacheTestSuite) TestCache(c *C) { + // first connection + t1 := time.Now() + var err error + s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara)) + c.Assert(err, IsNil) + s.db.SetMaxIdleConns(4) + s.runSelect(c) + t2 := time.Now() + + d1 := int(t2.Sub(t1).Nanoseconds() / 1e6) + //log.Debugf("first connection took %d milliseconds", d1) + + c.Assert(d1, GreaterEqual, delay) + + if s.db != nil { + s.db.Close() + } + + // second connection + t3 := time.Now() + s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara)) + c.Assert(err, IsNil) + s.db.SetMaxIdleConns(4) + s.runSelect(c) + t4 := time.Now() + + d2 := int(t4.Sub(t3).Nanoseconds() / 1e6) + //log.Debugf("second connection took %d milliseconds", d2) + + c.Assert(d2, Less, delay) + if s.db != nil { + s.db.Close() + } + + s.server.cacheShaPassword = &sync.Map{} +} + +type testCacheHandler struct { + s *cacheTestSuite +} + +func (h *testCacheHandler) UseDB(dbName string) error { + return nil +} + +func (h *testCacheHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { + ss := strings.Split(query, " ") + switch strings.ToLower(ss[0]) { + case "select": + var r *mysql.Resultset + var err error + //for handle go mysql driver select @@max_allowed_packet + if strings.Contains(strings.ToLower(query), "max_allowed_packet") { + r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ + {mysql.MaxPayloadLen}, + }, binary) + } else { + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ + {1, "hello world"}, + }, binary) + } + + if err != nil { + return nil, errors.Trace(err) + } else { + return &mysql.Result{0, 0, 0, r}, nil + } + case "insert": + return &mysql.Result{0, 1, 0, nil}, nil + case "delete": + return &mysql.Result{0, 0, 1, nil}, nil + case "update": + return &mysql.Result{0, 0, 1, nil}, nil + case "replace": + return &mysql.Result{0, 0, 1, nil}, nil + default: + return nil, fmt.Errorf("invalid query %s", query) + } + + return nil, nil +} + +func (h *testCacheHandler) HandleQuery(query string) (*mysql.Result, error) { + return h.handleQuery(query, false) +} + +func (h *testCacheHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + return nil, nil +} +func (h *testCacheHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) { + return 0, 0, nil, nil +} + +func (h *testCacheHandler) HandleStmtClose(context interface{}) error { + return nil +} + +func (h *testCacheHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) { + return h.handleQuery(query, true) +} + +func (h *testCacheHandler) HandleOtherCommand(cmd byte, data []byte) error { + return mysql.NewError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("command %d is not supported now", cmd)) +} diff --git a/vendor/github.com/siddontang/go-mysql/server/command.go b/vendor/github.com/siddontang/go-mysql/server/command.go index fb7dcdf..6c8d13a 100644 --- a/vendor/github.com/siddontang/go-mysql/server/command.go +++ b/vendor/github.com/siddontang/go-mysql/server/command.go @@ -11,8 +11,8 @@ import ( type Handler interface { //handle COM_INIT_DB command, you can check whether the dbName is valid, or other. UseDB(dbName string) error - //handle COM_QUERY comamnd, like SELECT, INSERT, UPDATE, etc... - //If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the repsonse, otherwise, we will send Result + //handle COM_QUERY command, like SELECT, INSERT, UPDATE, etc... + //If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the response, otherwise, we will send Result HandleQuery(query string) (*Result, error) //handle COM_FILED_LIST command HandleFieldList(table string, fieldWildcard string) ([]*Field, error) @@ -25,6 +25,9 @@ type Handler interface { //handle COM_STMT_CLOSE, context is the previous one set in prepare //this handler has no response HandleStmtClose(context interface{}) error + //handle any other command that is not currently handled by the library, + //default implementation for this method will return an ER_UNKNOWN_ERROR + HandleOtherCommand(cmd byte, data []byte) error } func (c *Conn) HandleCommand() error { @@ -119,8 +122,7 @@ func (c *Conn) dispatch(data []byte) interface{} { return r } default: - msg := fmt.Sprintf("command %d is not supported now", cmd) - return NewError(ER_UNKNOWN_ERROR, msg) + return c.h.HandleOtherCommand(cmd, data) } return fmt.Errorf("command %d is not handled correctly", cmd) @@ -149,3 +151,10 @@ func (h EmptyHandler) HandleStmtExecute(context interface{}, query string, args func (h EmptyHandler) HandleStmtClose(context interface{}) error { return nil } + +func (h EmptyHandler) HandleOtherCommand(cmd byte, data []byte) error { + return NewError( + ER_UNKNOWN_ERROR, + fmt.Sprintf("command %d is not supported now", cmd), + ) +} diff --git a/vendor/github.com/siddontang/go-mysql/server/command_test.go b/vendor/github.com/siddontang/go-mysql/server/command_test.go new file mode 100644 index 0000000..34b034e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/command_test.go @@ -0,0 +1,4 @@ +package server + +// Ensure EmptyHandler implements Handler interface or cause compile time error +var _ Handler = EmptyHandler{} diff --git a/vendor/github.com/siddontang/go-mysql/server/conn.go b/vendor/github.com/siddontang/go-mysql/server/conn.go index d6ea846..a279b93 100644 --- a/vendor/github.com/siddontang/go-mysql/server/conn.go +++ b/vendor/github.com/siddontang/go-mysql/server/conn.go @@ -15,15 +15,17 @@ import ( type Conn struct { *packet.Conn - capability uint32 + serverConf *Server + capability uint32 + authPluginName string + connectionID uint32 + status uint16 + salt []byte // should be 8 + 12 for auth-plugin-data-part-1 and auth-plugin-data-part-2 - connectionID uint32 - - status uint16 - - user string - - salt []byte + credentialProvider CredentialProvider + user string + password string + cachingSha2FullAuth bool h Handler @@ -35,23 +37,23 @@ type Conn struct { var baseConnID uint32 = 10000 +// create connection with default server settings func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, error) { - c := new(Conn) - - c.h = h - - c.user = user - c.Conn = packet.NewConn(conn) - - c.connectionID = atomic.AddUint32(&baseConnID, 1) - - c.stmts = make(map[uint32]*Stmt) - - c.salt, _ = RandomBuf(20) - + p := NewInMemoryProvider() + p.AddUser(user, password) + salt, _ := RandomBuf(20) + c := &Conn{ + Conn: packet.NewConn(conn), + serverConf: defaultServer, + credentialProvider: p, + h: h, + connectionID: atomic.AddUint32(&baseConnID, 1), + stmts: make(map[uint32]*Stmt), + salt: salt, + } c.closed.Set(false) - if err := c.handshake(password); err != nil { + if err := c.handshake(); err != nil { c.Close() return nil, err } @@ -59,14 +61,38 @@ func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, err return c, nil } -func (c *Conn) handshake(password string) error { +// create connection with customized server settings +func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, h Handler) (*Conn, error) { + salt, _ := RandomBuf(20) + c := &Conn{ + Conn: packet.NewConn(conn), + serverConf: serverConf, + credentialProvider: p, + h: h, + connectionID: atomic.AddUint32(&baseConnID, 1), + stmts: make(map[uint32]*Stmt), + salt: salt, + } + c.closed.Set(false) + + if err := c.handshake(); err != nil { + c.Close() + return nil, err + } + + return c, nil +} + +func (c *Conn) handshake() error { if err := c.writeInitialHandshake(); err != nil { return err } - if err := c.readHandshakeResponse(password); err != nil { + if err := c.readHandshakeResponse(); err != nil { + if err == ErrAccessDenied { + err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.LocalAddr().String(), "Yes") + } c.writeError(err) - return err } diff --git a/vendor/github.com/siddontang/go-mysql/server/credential_provider.go b/vendor/github.com/siddontang/go-mysql/server/credential_provider.go new file mode 100644 index 0000000..3d44eb0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/credential_provider.go @@ -0,0 +1,45 @@ +package server + +import "sync" + +// interface for user credential provider +// hint: can be extended for more functionality +// =================================IMPORTANT NOTE=============================== +// if the password in a third-party credential provider could be updated at runtime, we have to invalidate the caching +// for 'caching_sha2_password' by calling 'func (s *Server)InvalidateCache(string, string)'. +type CredentialProvider interface { + // check if the user exists + CheckUsername(username string) (bool, error) + // get user credential + GetCredential(username string) (password string, found bool, err error) +} + +func NewInMemoryProvider() *InMemoryProvider { + return &InMemoryProvider{ + userPool: sync.Map{}, + } +} + +// implements a in memory credential provider +type InMemoryProvider struct { + userPool sync.Map // username -> password +} + +func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error) { + _, ok := m.userPool.Load(username) + return ok, nil +} + +func (m *InMemoryProvider) GetCredential(username string) (password string, found bool, err error) { + v, ok := m.userPool.Load(username) + if !ok { + return "", false, nil + } + return v.(string), true, nil +} + +func (m *InMemoryProvider) AddUser(username, password string) { + m.userPool.Store(username, password) +} + +type Provider InMemoryProvider diff --git a/vendor/github.com/siddontang/go-mysql/server/example/server_example.go b/vendor/github.com/siddontang/go-mysql/server/example/server_example.go new file mode 100644 index 0000000..1efa1a3 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/example/server_example.go @@ -0,0 +1,51 @@ +package main + +import ( + "net" + + "github.com/siddontang/go-log/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/server" + "github.com/siddontang/go-mysql/test_util/test_keys" + + "crypto/tls" + "time" +) + +type RemoteThrottleProvider struct { + *server.InMemoryProvider + delay int // in milliseconds +} + +func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) { + time.Sleep(time.Millisecond * time.Duration(m.delay)) + return m.InMemoryProvider.GetCredential(username) +} + +func main() { + l, _ := net.Listen("tcp", "127.0.0.1:3306") + // user either the in-memory credential provider or the remote credential provider (you can implement your own) + //inMemProvider := server.NewInMemoryProvider() + //inMemProvider.AddUser("root", "123") + remoteProvider := &RemoteThrottleProvider{server.NewInMemoryProvider(), 10 + 50} + remoteProvider.AddUser("root", "123") + var tlsConf = server.NewServerTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, tls.VerifyClientCertIfGiven) + for { + c, _ := l.Accept() + go func() { + // Create a connection with user root and an empty password. + // You can use your own handler to handle command here. + svr := server.NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) + conn, err := server.NewCustomizedConn(c, svr, remoteProvider, server.EmptyHandler{}) + + if err != nil { + log.Errorf("Connection error: %v", err) + return + } + + for { + conn.HandleCommand() + } + }() + } +} diff --git a/vendor/github.com/siddontang/go-mysql/server/handshake_resp.go b/vendor/github.com/siddontang/go-mysql/server/handshake_resp.go new file mode 100644 index 0000000..79af6f2 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/handshake_resp.go @@ -0,0 +1,190 @@ +package server + +import ( + "bytes" + "crypto/tls" + "encoding/binary" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" +) + +func (c *Conn) readHandshakeResponse() error { + data, pos, err := c.readFirstPart() + if err != nil { + return err + } + if pos, err = c.readUserName(data, pos); err != nil { + return err + } + authData, authLen, pos, err := c.readAuthData(data, pos) + if err != nil { + return err + } + + pos += authLen + + if pos, err = c.readDb(data, pos); err != nil { + return err + } + + pos = c.readPluginName(data, pos) + + cont, err := c.handleAuthMatch(authData, pos) + if err != nil { + return err + } + if !cont { + return nil + } + + // ignore connect attrs for now, the proxy does not support passing attrs to actual MySQL server + + // try to authenticate the client + return c.compareAuthData(c.authPluginName, authData) +} + +func (c *Conn) readFirstPart() ([]byte, int, error) { + data, err := c.ReadPacket() + if err != nil { + return nil, 0, err + } + + pos := 0 + + // check CLIENT_PROTOCOL_41 + if uint32(binary.LittleEndian.Uint16(data[:2]))&CLIENT_PROTOCOL_41 == 0 { + return nil, 0, errors.New("CLIENT_PROTOCOL_41 compatible client is required") + } + + //capability + c.capability = binary.LittleEndian.Uint32(data[:4]) + if c.capability&CLIENT_SECURE_CONNECTION == 0 { + return nil, 0, errors.New("CLIENT_SECURE_CONNECTION compatible client is required") + } + pos += 4 + + //skip max packet size + pos += 4 + + //charset, skip, if you want to use another charset, use set names + //c.collation = CollationId(data[pos]) + pos++ + + //skip reserved 23[00] + pos += 23 + + // is this a SSLRequest packet? + if len(data) == (4 + 4 + 1 + 23) { + if c.serverConf.capability&CLIENT_SSL == 0 { + return nil, 0, errors.Errorf("The host '%s' does not support SSL connections", c.RemoteAddr().String()) + } + // switch to TLS + tlsConn := tls.Server(c.Conn.Conn, c.serverConf.tlsConfig) + if err := tlsConn.Handshake(); err != nil { + return nil, 0, err + } + c.Conn.Conn = tlsConn + + // mysql handshake again + return c.readFirstPart() + } + return data, pos, nil +} + +func (c *Conn) readUserName(data []byte, pos int) (int, error) { + //user name + user := string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(user) + 1 + c.user = user + return pos, nil +} + +func (c *Conn) readDb(data []byte, pos int) (int, error) { + if c.capability&CLIENT_CONNECT_WITH_DB != 0 { + if len(data[pos:]) == 0 { + return pos, nil + } + + db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(db) + 1 + + if err := c.h.UseDB(db); err != nil { + return 0, err + } + } + return pos, nil +} + +func (c *Conn) readPluginName(data []byte, pos int) int { + if c.capability&CLIENT_PLUGIN_AUTH != 0 { + c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(c.authPluginName) + } else { + // The method used is Native Authentication if both CLIENT_PROTOCOL_41 and CLIENT_SECURE_CONNECTION are set, + // but CLIENT_PLUGIN_AUTH is not set, so we fallback to 'mysql_native_password' + c.authPluginName = AUTH_NATIVE_PASSWORD + } + return pos +} + +func (c *Conn) readAuthData(data []byte, pos int) ([]byte, int, int, error) { + // length encoded data + var auth []byte + var authLen int + if c.capability&CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 { + authData, isNULL, readBytes, err := LengthEncodedString(data[pos:]) + if err != nil { + return nil, 0, 0, err + } + if isNULL { + // no auth length and no auth data, just \NUL, considered invalid auth data, and reject connection as MySQL does + return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.LocalAddr().String(), c.user, "Yes") + } + auth = authData + authLen = readBytes + } else { + //auth length and auth + authLen = int(data[pos]) + pos++ + auth = data[pos : pos+authLen] + if authLen == 0 { + // skip the next \NUL in case the password is empty + pos++ + } + } + return auth, authLen, pos, nil +} + +// Public Key Retrieval +// See: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html +func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { + // if the client use 'sha256_password' auth method, and request for a public key + // we send back a keyfile with Protocol::AuthMoreData + if c.authPluginName == AUTH_SHA256_PASSWORD && len(authData) == 1 && authData[0] == 0x01 { + if c.serverConf.capability&CLIENT_SSL == 0 { + return false, errors.New("server does not support SSL: CLIENT_SSL not enabled") + } + if err := c.writeAuthMoreDataPubkey(); err != nil { + return false, err + } + + return false, c.handleAuthSwitchResponse() + } + return true, nil +} + +func (c *Conn) handleAuthMatch(authData []byte, pos int) (bool, error) { + // if the client responds the handshake with a different auth method, the server will send the AuthSwitchRequest packet + // to the client to ask the client to switch. + + if c.authPluginName != c.serverConf.defaultAuthMethod { + if err := c.writeAuthSwitchRequest(c.serverConf.defaultAuthMethod); err != nil { + return false, err + } + c.authPluginName = c.serverConf.defaultAuthMethod + // handle AuthSwitchResponse + return false, c.handleAuthSwitchResponse() + } + return true, nil +} diff --git a/vendor/github.com/siddontang/go-mysql/server/initial_handshake.go b/vendor/github.com/siddontang/go-mysql/server/initial_handshake.go new file mode 100644 index 0000000..312ac2b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/initial_handshake.go @@ -0,0 +1,57 @@ +package server + +// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +func (c *Conn) writeInitialHandshake() error { + data := make([]byte, 4) + + //min version 10 + data = append(data, 10) + + //server version[00] + data = append(data, c.serverConf.serverVersion...) + data = append(data, 0x00) + + //connection id + data = append(data, byte(c.connectionID), byte(c.connectionID>>8), byte(c.connectionID>>16), byte(c.connectionID>>24)) + + //auth-plugin-data-part-1 + data = append(data, c.salt[0:8]...) + + //filter 0x00 byte, terminating the first part of a scramble + data = append(data, 0x00) + + defaultFlag := c.serverConf.capability + //capability flag lower 2 bytes, using default capability here + data = append(data, byte(defaultFlag), byte(defaultFlag>>8)) + + //charset + data = append(data, c.serverConf.collationId) + + //status + data = append(data, byte(c.status), byte(c.status>>8)) + + //capability flag upper 2 bytes, using default capability here + data = append(data, byte(defaultFlag>>16), byte(defaultFlag>>24)) + + // server supports CLIENT_PLUGIN_AUTH and CLIENT_SECURE_CONNECTION + data = append(data, byte(8+12+1)) + + //reserved 10 [00] + data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + + //auth-plugin-data-part-2 + data = append(data, c.salt[8:]...) + // second part of the password cipher [mininum 13 bytes], + // where len=MAX(13, length of auth-plugin-data - 8) + // add \NUL to terminate the string + data = append(data, 0x00) + + // auth plugin name + data = append(data, c.serverConf.defaultAuthMethod...) + + // EOF if MySQL version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) + // \NUL otherwise, so we use \NUL + data = append(data, 0) + + return c.WritePacket(data) +} diff --git a/vendor/github.com/siddontang/go-mysql/server/resp.go b/vendor/github.com/siddontang/go-mysql/server/resp.go index 1123032..db86323 100644 --- a/vendor/github.com/siddontang/go-mysql/server/resp.go +++ b/vendor/github.com/siddontang/go-mysql/server/resp.go @@ -62,6 +62,59 @@ func (c *Conn) writeEOF() error { return c.WritePacket(data) } +// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html +func (c *Conn) writeAuthSwitchRequest(newAuthPluginName string) error { + data := make([]byte, 4) + data = append(data, EOF_HEADER) + data = append(data, []byte(newAuthPluginName)...) + data = append(data, 0x00) + rnd, err := RandomBuf(20) + if err != nil { + return err + } + // new auth data + c.salt = rnd + data = append(data, c.salt...) + // the online doc states it's a string.EOF, however, the actual MySQL server add a \NUL to the end, without it, the + // official MySQL client will fail. + data = append(data, 0x00) + return c.WritePacket(data) +} + +// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_response.html +func (c *Conn) readAuthSwitchRequestResponse() ([]byte, error) { + data, err := c.ReadPacket() + if err != nil { + return nil, err + } + if len(data) == 1 && data[0] == 0x00 { + // \NUL + return make([]byte, 0), nil + } + return data, nil +} + +func (c *Conn) writeAuthMoreDataPubkey() error { + data := make([]byte, 4) + data = append(data, MORE_DATE_HEADER) + data = append(data, c.serverConf.pubKey...) + return c.WritePacket(data) +} + +func (c *Conn) writeAuthMoreDataFullAuth() error { + data := make([]byte, 4) + data = append(data, MORE_DATE_HEADER) + data = append(data, CACHE_SHA2_FULL_AUTH) + return c.WritePacket(data) +} + +func (c *Conn) writeAuthMoreDataFastAuth() error { + data := make([]byte, 4) + data = append(data, MORE_DATE_HEADER) + data = append(data, CACHE_SHA2_FAST_AUTH) + return c.WritePacket(data) +} + func (c *Conn) writeResultset(r *Resultset) error { columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) diff --git a/vendor/github.com/siddontang/go-mysql/server/server_conf.go b/vendor/github.com/siddontang/go-mysql/server/server_conf.go new file mode 100644 index 0000000..353595c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/server_conf.go @@ -0,0 +1,103 @@ +package server + +import ( + "crypto/tls" + "fmt" + "sync" + + . "github.com/siddontang/go-mysql/mysql" +) + +var defaultServer = NewDefaultServer() + +// Defines a basic MySQL server with configs. +// +// We do not aim at implementing the whole MySQL connection suite to have the best compatibilities for the clients. +// The MySQL server can be configured to switch auth methods covering 'mysql_old_password', 'mysql_native_password', +// 'mysql_clear_password', 'authentication_windows_client', 'sha256_password', 'caching_sha2_password', etc. +// +// However, since some old auth methods are considered broken with security issues. MySQL major versions like 5.7 and 8.0 default to +// 'mysql_native_password' or 'caching_sha2_password', and most MySQL clients should have already supported at least one of the three auth +// methods 'mysql_native_password', 'caching_sha2_password', and 'sha256_password'. Thus here we will only support these three +// auth methods, and use 'mysql_native_password' as default for maximum compatibility with the clients and leave the other two as +// config options. +// +// The MySQL doc states that 'mysql_old_password' will be used if 'CLIENT_PROTOCOL_41' or 'CLIENT_SECURE_CONNECTION' flag is not set. +// We choose to drop the support for insecure 'mysql_old_password' auth method and require client capability 'CLIENT_PROTOCOL_41' and 'CLIENT_SECURE_CONNECTION' +// are set. Besides, if 'CLIENT_PLUGIN_AUTH' is not set, we fallback to 'mysql_native_password' auth method. +type Server struct { + serverVersion string // e.g. "8.0.12" + protocolVersion int // minimal 10 + capability uint32 // server capability flag + collationId uint8 + defaultAuthMethod string // default authentication method, 'mysql_native_password' + pubKey []byte + tlsConfig *tls.Config + cacheShaPassword *sync.Map // 'user@host' -> SHA256(SHA256(PASSWORD)) +} + +// New mysql server with default settings. +// +// NOTES: +// TLS support will be enabled by default with auto-generated CA and server certificates (however, you can still use +// non-TLS connection). By default, it will verify the client certificate if present. You can enable TLS support on +// the client side without providing a client-side certificate. So only when you need the server to verify client +// identity for maximum security, you need to set a signed certificate for the client. +func NewDefaultServer() *Server { + caPem, caKey := generateCA() + certPem, keyPem := generateAndSignRSACerts(caPem, caKey) + tlsConf := NewServerTLSConfig(caPem, certPem, keyPem, tls.VerifyClientCertIfGiven) + return &Server{ + serverVersion: "5.7.0", + protocolVersion: 10, + capability: CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | + CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_SSL | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, + collationId: DEFAULT_COLLATION_ID, + defaultAuthMethod: AUTH_NATIVE_PASSWORD, + pubKey: getPublicKeyFromCert(certPem), + tlsConfig: tlsConf, + cacheShaPassword: new(sync.Map), + } +} + +// New mysql server with customized settings. +// +// NOTES: +// You can control the authentication methods and TLS settings here. +// For auth method, you can specify one of the supported methods 'mysql_native_password', 'caching_sha2_password', and 'sha256_password'. +// The specified auth method will be enforced by the server in the connection phase. That means, client will be asked to switch auth method +// if the supplied auth method is different from the server default. +// And for TLS support, you can specify self-signed or CA-signed certificates and decide whether the client needs to provide +// a signed or unsigned certificate to provide different level of security. +func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config) *Server { + if !isAuthMethodSupported(defaultAuthMethod) { + panic(fmt.Sprintf("server authentication method '%s' is not supported", defaultAuthMethod)) + } + + //if !isAuthMethodAllowedByServer(defaultAuthMethod, allowedAuthMethods) { + // panic(fmt.Sprintf("default auth method is not one of the allowed auth methods")) + //} + var capFlag = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | + CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + if tlsConfig != nil { + capFlag |= CLIENT_SSL + } + return &Server{ + serverVersion: serverVersion, + protocolVersion: 10, + capability: capFlag, + collationId: collationId, + defaultAuthMethod: defaultAuthMethod, + pubKey: pubKey, + tlsConfig: tlsConfig, + cacheShaPassword: new(sync.Map), + } +} + +func isAuthMethodSupported(authMethod string) bool { + return authMethod == AUTH_NATIVE_PASSWORD || authMethod == AUTH_CACHING_SHA2_PASSWORD || authMethod == AUTH_SHA256_PASSWORD +} + +func (s *Server) InvalidateCache(username string, host string) { + s.cacheShaPassword.Delete(fmt.Sprintf("%s@%s", username, host)) +} diff --git a/vendor/github.com/siddontang/go-mysql/server/server_test.go b/vendor/github.com/siddontang/go-mysql/server/server_test.go index 54bcf05..1f427fd 100644 --- a/vendor/github.com/siddontang/go-mysql/server/server_test.go +++ b/vendor/github.com/siddontang/go-mysql/server/server_test.go @@ -1,6 +1,7 @@ package server import ( + "crypto/tls" "database/sql" "flag" "fmt" @@ -12,110 +13,90 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/juju/errors" . "github.com/pingcap/check" - mysql "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-log/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/test_util/test_keys" ) var testAddr = flag.String("addr", "127.0.0.1:4000", "MySQL proxy server address") var testUser = flag.String("user", "root", "MySQL user") -var testPassword = flag.String("pass", "", "MySQL password") +var testPassword = flag.String("pass", "123456", "MySQL password") var testDB = flag.String("db", "test", "MySQL test database") +var tlsConf = NewServerTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, tls.VerifyClientCertIfGiven) + +func prepareServerConf() []*Server { + // add default server without TLS + var servers = []*Server{ + // with default TLS + NewDefaultServer(), + // for key exchange, CLIENT_SSL must be enabled for the server and if the connection is not secured with TLS + // server permits MYSQL_NATIVE_PASSWORD only + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + // server permits SHA256_PASSWORD only + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_SHA256_PASSWORD, test_keys.PubPem, tlsConf), + // server permits CACHING_SHA2_PASSWORD only + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf), + + // test auth switch: server permits SHA256_PASSWORD only but sent different method MYSQL_NATIVE_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits CACHING_SHA2_PASSWORD only but sent different method MYSQL_NATIVE_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits CACHING_SHA2_PASSWORD only but sent different method SHA256_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_SHA256_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits MYSQL_NATIVE_PASSWORD only but sent different method SHA256_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_SHA256_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits SHA256_PASSWORD only but sent different method CACHING_SHA2_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits MYSQL_NATIVE_PASSWORD only but sent different method CACHING_SHA2_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf), + } + return servers +} + func Test(t *testing.T) { + log.SetLevel(log.LevelDebug) + + // general tests + inMemProvider := NewInMemoryProvider() + inMemProvider.AddUser(*testUser, *testPassword) + + servers := prepareServerConf() + //no TLS + for _, svr := range servers { + Suite(&serverTestSuite{ + server: svr, + credProvider: inMemProvider, + tlsPara: "false", + }) + } + + // TLS if server supports + for _, svr := range servers { + if svr.tlsConfig != nil { + Suite(&serverTestSuite{ + server: svr, + credProvider: inMemProvider, + tlsPara: "skip-verify", + }) + } + } + TestingT(t) } type serverTestSuite struct { + server *Server + credProvider CredentialProvider + + tlsPara string + db *sql.DB l net.Listener } -var _ = Suite(&serverTestSuite{}) - -type testHandler struct { - s *serverTestSuite -} - -func (h *testHandler) UseDB(dbName string) error { - return nil -} - -func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { - ss := strings.Split(query, " ") - switch strings.ToLower(ss[0]) { - case "select": - var r *mysql.Resultset - var err error - //for handle go mysql driver select @@max_allowed_packet - if strings.Contains(strings.ToLower(query), "max_allowed_packet") { - r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ - []interface{}{mysql.MaxPayloadLen}, - }, binary) - } else { - r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ - []interface{}{1, "hello world"}, - }, binary) - } - - if err != nil { - return nil, errors.Trace(err) - } else { - return &mysql.Result{0, 0, 0, r}, nil - } - case "insert": - return &mysql.Result{0, 1, 0, nil}, nil - case "delete": - return &mysql.Result{0, 0, 1, nil}, nil - case "update": - return &mysql.Result{0, 0, 1, nil}, nil - case "replace": - return &mysql.Result{0, 0, 1, nil}, nil - default: - return nil, fmt.Errorf("invalid query %s", query) - } - - return nil, nil -} - -func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) { - return h.handleQuery(query, false) -} - -func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { - return nil, nil -} -func (h *testHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) { - ss := strings.Split(sql, " ") - switch strings.ToLower(ss[0]) { - case "select": - params = 1 - columns = 2 - case "insert": - params = 2 - columns = 0 - case "replace": - params = 2 - columns = 0 - case "update": - params = 1 - columns = 0 - case "delete": - params = 1 - columns = 0 - default: - err = fmt.Errorf("invalid prepare %s", sql) - } - return params, columns, nil, err -} - -func (h *testHandler) HandleStmtClose(context interface{}) error { - return nil -} - -func (h *testHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) { - return h.handleQuery(query, true) -} - func (s *serverTestSuite) SetUpSuite(c *C) { var err error @@ -124,9 +105,9 @@ func (s *serverTestSuite) SetUpSuite(c *C) { go s.onAccept(c) - time.Sleep(500 * time.Millisecond) + time.Sleep(20 * time.Millisecond) - s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", *testUser, *testPassword, *testAddr, *testDB)) + s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara)) c.Assert(err, IsNil) s.db.SetMaxIdleConns(4) @@ -154,9 +135,10 @@ func (s *serverTestSuite) onAccept(c *C) { } func (s *serverTestSuite) onConn(conn net.Conn, c *C) { - co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) + //co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) + co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testHandler{s}) c.Assert(err, IsNil) - + // set SSL if defined for { err = co.HandleCommand() if err != nil { @@ -228,3 +210,91 @@ func (s *serverTestSuite) TestStmtExec(c *C) { i, _ = r.RowsAffected() c.Assert(i, Equals, int64(1)) } + +type testHandler struct { + s *serverTestSuite +} + +func (h *testHandler) UseDB(dbName string) error { + return nil +} + +func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { + ss := strings.Split(query, " ") + switch strings.ToLower(ss[0]) { + case "select": + var r *mysql.Resultset + var err error + //for handle go mysql driver select @@max_allowed_packet + if strings.Contains(strings.ToLower(query), "max_allowed_packet") { + r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ + {mysql.MaxPayloadLen}, + }, binary) + } else { + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ + {1, "hello world"}, + }, binary) + } + + if err != nil { + return nil, errors.Trace(err) + } else { + return &mysql.Result{0, 0, 0, r}, nil + } + case "insert": + return &mysql.Result{0, 1, 0, nil}, nil + case "delete": + return &mysql.Result{0, 0, 1, nil}, nil + case "update": + return &mysql.Result{0, 0, 1, nil}, nil + case "replace": + return &mysql.Result{0, 0, 1, nil}, nil + default: + return nil, fmt.Errorf("invalid query %s", query) + } + + return nil, nil +} + +func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) { + return h.handleQuery(query, false) +} + +func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + return nil, nil +} +func (h *testHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) { + ss := strings.Split(sql, " ") + switch strings.ToLower(ss[0]) { + case "select": + params = 1 + columns = 2 + case "insert": + params = 2 + columns = 0 + case "replace": + params = 2 + columns = 0 + case "update": + params = 1 + columns = 0 + case "delete": + params = 1 + columns = 0 + default: + err = fmt.Errorf("invalid prepare %s", sql) + } + return params, columns, nil, err +} + +func (h *testHandler) HandleStmtClose(context interface{}) error { + return nil +} + +func (h *testHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) { + return h.handleQuery(query, true) +} + +func (h *testHandler) HandleOtherCommand(cmd byte, data []byte) error { + return mysql.NewError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("command %d is not supported now", cmd)) +} \ No newline at end of file diff --git a/vendor/github.com/siddontang/go-mysql/server/ssl.go b/vendor/github.com/siddontang/go-mysql/server/ssl.go new file mode 100644 index 0000000..1f8a9ed --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/ssl.go @@ -0,0 +1,133 @@ +package server + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" +) + +// generate TLS config for server side +// controlling the security level by authType +func NewServerTLSConfig(caPem, certPem, keyPem []byte, authType tls.ClientAuthType) *tls.Config { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPem) { + panic("failed to add ca PEM") + } + + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + panic(err) + } + + config := &tls.Config{ + ClientAuth: authType, + Certificates: []tls.Certificate{cert}, + ClientCAs: pool, + } + return config +} + +// extract RSA public key from certificate +func getPublicKeyFromCert(certPem []byte) []byte { + block, _ := pem.Decode(certPem) + crt, err := x509.ParseCertificate(block.Bytes) + if err != nil { + panic(err) + } + pubKey, err := x509.MarshalPKIXPublicKey(crt.PublicKey.(*rsa.PublicKey)) + if err != nil { + panic(err) + } + return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubKey}) +} + +// generate and sign RSA certificates with given CA +// see: https://fale.io/blog/2017/06/05/create-a-pki-in-golang/ +func generateAndSignRSACerts(caPem, caKey []byte) ([]byte, []byte) { + // Load CA + catls, err := tls.X509KeyPair(caPem, caKey) + if err != nil { + panic(err) + } + ca, err := x509.ParseCertificate(catls.Certificate[0]) + if err != nil { + panic(err) + } + + // use the CA to sign certificates + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + panic(err) + } + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"ORGANIZATION_NAME"}, + Country: []string{"COUNTRY_CODE"}, + Province: []string{"PROVINCE"}, + Locality: []string{"CITY"}, + StreetAddress: []string{"ADDRESS"}, + PostalCode: []string{"POSTAL_CODE"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + + // sign the certificate + cert_b, err := x509.CreateCertificate(rand.Reader, ca, cert, &priv.PublicKey, catls.PrivateKey) + if err != nil { + panic(err) + } + certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert_b}) + keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return certPem, keyPem +} + +// generate CA in PEM +// see: https://github.com/golang/go/blob/master/src/crypto/tls/generate_cert.go +func generateCA() ([]byte, []byte) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + panic(err) + } + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"ORGANIZATION_NAME"}, + Country: []string{"COUNTRY_CODE"}, + Province: []string{"PROVINCE"}, + Locality: []string{"CITY"}, + StreetAddress: []string{"ADDRESS"}, + PostalCode: []string{"POSTAL_CODE"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, + BasicConstraintsValid: true, + } + + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + panic(err) + } + + caPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + caKey := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return caPem, caKey +} diff --git a/vendor/github.com/siddontang/go-mysql/server/stmt.go b/vendor/github.com/siddontang/go-mysql/server/stmt.go index 7a325d7..9bef23e 100644 --- a/vendor/github.com/siddontang/go-mysql/server/stmt.go +++ b/vendor/github.com/siddontang/go-mysql/server/stmt.go @@ -144,7 +144,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { } paramTypes = data[pos : pos+(paramNum<<1)] - pos += (paramNum << 1) + pos += paramNum << 1 paramValues = data[pos:] } @@ -211,7 +211,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) if isUnsigned { args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) } else { - args[i] = int16((binary.LittleEndian.Uint16(paramValues[pos : pos+2]))) + args[i] = int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) } pos += 2 continue @@ -270,7 +270,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) return ErrMalformPacket } - v, isNull, n, err = LengthEnodedString(paramValues[pos:]) + v, isNull, n, err = LengthEncodedString(paramValues[pos:]) pos += n if err != nil { return errors.Trace(err) @@ -290,7 +290,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) return nil } -// stmt send long data command has no repsonse +// stmt send long data command has no response func (c *Conn) handleStmtSendLongData(data []byte) error { if len(data) < 6 { return nil @@ -340,7 +340,7 @@ func (c *Conn) handleStmtReset(data []byte) (*Result, error) { return &Result{}, nil } -// stmt close command has no repsonse +// stmt close command has no response func (c *Conn) handleStmtClose(data []byte) error { if len(data) < 4 { return nil diff --git a/vendor/github.com/siddontang/go-mysql/test_util/test_keys/keys.go b/vendor/github.com/siddontang/go-mysql/test_util/test_keys/keys.go new file mode 100644 index 0000000..c1049b6 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/test_util/test_keys/keys.go @@ -0,0 +1,85 @@ +package test_keys + +// here we put the testing encryption keys here +// NOTE THIS IS FOR TESTING ONLY, DO NOT USE THEM IN PRODUCTION! + +var PubPem = []byte(`-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsraCori69OXEkA07Ykp2 +Ju+aHz33PqgKj0qbSbPm6ePh2mer2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqq +gM7uv9JoWCONcKo28cPRK7gH7up7nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz +5Bvi9fThcpf9M6qKdNUa2Wd24rM/n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ +ZaEG86NXFpvvnLHRHiih/gXe7nby1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtV +SX4poNq9YrdP4XW9M0N7nocRU0p5aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGT +MwIDAQAB +-----END PUBLIC KEY-----`) + +var CertPem = []byte(`-----BEGIN CERTIFICATE----- +MIIDBjCCAe4CCQDg06wCf7hcuDANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDUyNVoXDTI4MDgxNjA4NDUyNVowRTELMAkG +A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0 +IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALK2gqK4uvTlxJANO2JKdibvmh899z6oCo9Km0mz5unj4dpnq9hljsQuKtcHUcM4 +HXcE06knaJ4TOF7lcsjaqoDO7r/SaFgjjXCqNvHD0Su4B+7qe52BZZTRV1AANP10 +PvebarXSEzaZUCyHHhSF8+Qb4vX04XKX/TOqinTVGtlnduKzP5+qsaFBtpLAw1V0 +At9EQB5BgnTYtdIsmvD4/2WhBvOjVxab75yx0R4oof4F3u528tbEegcWhBtmy2Xd +HI3S+TLljj3kOOdB+pgrVUl+KaDavWK3T+F1vTNDe56HEVNKeWlLy1scul61E0j9 +IkZAu6aRDxtKdl7bKu0BkzMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAma3yFqR7 +xkeaZBg4/1I3jSlaNe5+2JB4iybAkMOu77fG5zytLomTbzdhewsuBwpTVMJdga8T +IdPeIFCin1U+5SkbjSMlpKf+krE+5CyrNJ5jAgO9ATIqx66oCTYXfGlNapGRLfSE +sa0iMqCe/dr4GPU+flW2DZFWiyJVDSF1JjReQnfrWY+SD2SpP/lmlgltnY8MJngd +xBLG5nsZCpUXGB713Q8ZyIm2ThVAMiskcxBleIZDDghLuhGvY/9eFJhZpvOkjWa6 +XGEi4E1G/SA+zVKFl41nHKCdqXdmIOnpcLlFBUVloQok5a95Kqc1TYw3f+WbdFff +99dAgk3gWwWZQA== +-----END CERTIFICATE-----`) + +var KeyPem = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAsraCori69OXEkA07Ykp2Ju+aHz33PqgKj0qbSbPm6ePh2mer +2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqqgM7uv9JoWCONcKo28cPRK7gH7up7 +nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz5Bvi9fThcpf9M6qKdNUa2Wd24rM/ +n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ZaEG86NXFpvvnLHRHiih/gXe7nby +1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtVSX4poNq9YrdP4XW9M0N7nocRU0p5 +aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGTMwIDAQABAoIBAGh1m8hHWCg7gXh9 +838RbRx3IswuKS27hWiaQEiFWmzOIb7KqDy1qAxtu+ayRY1paHegH6QY/+Kd824s +ibpzbgQacJ04/HrAVTVMmQ8Z2VLHoAN7lcPL1bd14aZGaLLZVtDeTDJ413grhxxv +4ho27gcgcbo4Z+rWgk7H2WRPCAGYqWYAycm3yF5vy9QaO6edU+T588YsEQOos5iy +5pVFSGDGZkcUp1ukL3BJYR+jvygn6WPCobQ/LScUdi+ucitaI9i+UdlLokZARVRG +M/msqcTM73thR8yVRcexU6NUDxRBfZ/f7moSAEbBmGDXuxDcIyH9KGMQ2rMtN1X3 +lK8UNwkCgYEA2STJq/IUQHjdqd3Dqh/Q7Zm8/pMWFqLJSkqpnFtFsXPyUOx9zDOy +KqkIfGeyKwvsj9X9BcZ0FUKj9zoct1/WpPY+h7i7+z0MIujBh4AMjAcDrt4o76yK +UHuVmG2xKTdJoAbqOdToQeX6E82Ioal5pbB2W7AbCQScNBPZ52jxgtcCgYEA0rE7 +2dFiRm0YmuszFYxft2+GP6NgP3R2TQNEooi1uCXG2xgwObie1YCHzpZ5CfSqJIxP +XB7DXpIWi7PxJoeai2F83LnmdFz6F1BPRobwDoSFNdaSKLg4Yf856zpgYNKhL1fE +OoOXj4VBWBZh1XDfZV44fgwlMIf7edOF1XOagwUCgYAw953O+7FbdKYwF0V3iOM5 +oZDAK/UwN5eC/GFRVDfcM5RycVJRCVtlSWcTfuLr2C2Jpiz/72fgH34QU3eEVsV1 +v94MBznFB1hESw7ReqvZq/9FoO3EVrl+OtBaZmosLD6bKtQJJJ0Xtz/01UW5hxla +pveZ55XBK9v51nwuNjk4UwKBgHD8fJUllSchUCWb5cwzeAz98Kdl7LJ6uQo5q2/i +EllLYOWThiEeIYdrIuklholRPIDXAaPsF2c6vn5yo+q+o6EFSZlw0+YpCjDAb5Lp +wAh5BprFk6HkkM/0t9Guf4rMyYWC8odSlE9x7YXYkuSMYDCTI4Zs6vCoq7I8PbQn +B4AlAoGAZ6Ee5m/ph5UVp/3+cR6jCY7aHBUU/M3pbJSkVjBW+ymEBVJ6sUdz8k3P +x8BiPEQggNN7faWBqRWP7KXPnDYHh6shYUgPJwI5HX6NE/ZDnnXjeysHRyf0oCo5 +S6tHXwHNKB5HS1c/KDyyNGjP2oi/MF4o/MGWNWEcK6TJA3RGOYM= +-----END RSA PRIVATE KEY-----`) + +var CaPem = []byte(`-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIJANeS1FOzWXlZMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTgwODE2MTUxNDE5WhcNMjEwNjA1MTUxNDE5WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QLsuT1fGf5 +zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eoxrS64GSb +PVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKvlpEAtRAe +4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7IeXUc456 +yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioONaePQKFR +/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABo4GnMIGkMB0GA1UdDgQWBBQgHiwD +00upIbCOunlK4HRw89DhjjB1BgNVHSMEbjBsgBQgHiwD00upIbCOunlK4HRw89Dh +jqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV +BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJANeS1FOzWXlZMAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAFMZFQTFKU5tWIpWh8BbVZeVZcng0Kiq +qwbhVwaTkqtfmbqw8/w+faOWylmLncQEMmgvnUltGMQlQKBwQM2byzPkz9phal3g +uI0JWJYqtcMyIQUB9QbbhrDNC9kdt/ji/x6rrIqzaMRuiBXqH5LQ9h856yXzArqd +cAQGzzYpbUCIv7ciSB93cKkU73fQLZVy5ZBy1+oAa1V9U4cb4G/20/PDmT+G3Gxz +pEjeDKtz8XINoWgA2cSdfAhNZt5vqJaCIZ8qN0z6C7SUKwUBderERUMLUXdhUldC +KTVHyEPvd0aULd5S5vEpKCnHcQmFcLdoN8t9k9pR9ZgwqXbyJHlxWFo= +-----END CERTIFICATE-----`) diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/COPYING b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/COPYING similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/COPYING rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/COPYING diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/toml-test-decoder/COPYING b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/toml-test-decoder/COPYING new file mode 100644 index 0000000..5a8e332 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/toml-test-decoder/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/toml-test-encoder/COPYING b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/toml-test-encoder/COPYING new file mode 100644 index 0000000..5a8e332 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/toml-test-encoder/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/tomlv/COPYING b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/tomlv/COPYING new file mode 100644 index 0000000..5a8e332 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/cmd/tomlv/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/decode.go similarity index 90% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/decode.go index 6c7d398..b0fd51d 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/decode.go @@ -10,7 +10,9 @@ import ( "time" ) -var e = fmt.Errorf +func e(format string, args ...interface{}) error { + return fmt.Errorf("toml: "+format, args...) +} // Unmarshaler is the interface implemented by objects that can unmarshal a // TOML description of themselves. @@ -103,6 +105,13 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error { // This decoder will not handle cyclic types. If a cyclic type is passed, // `Decode` will not terminate. func Decode(data string, v interface{}) (MetaData, error) { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v)) + } + if rv.IsNil() { + return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v)) + } p, err := parse(data) if err != nil { return MetaData{}, err @@ -111,7 +120,7 @@ func Decode(data string, v interface{}) (MetaData, error) { p.mapping, p.types, p.ordered, make(map[string]bool, len(p.ordered)), nil, } - return md, md.unify(p.mapping, rvalue(v)) + return md, md.unify(p.mapping, indirect(rv)) } // DecodeFile is just like Decode, except it will automatically read the @@ -211,7 +220,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error { case reflect.Interface: // we only support empty interfaces. if rv.NumMethod() > 0 { - return e("Unsupported type '%s'.", rv.Kind()) + return e("unsupported type %s", rv.Type()) } return md.unifyAnything(data, rv) case reflect.Float32: @@ -219,13 +228,17 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error { case reflect.Float64: return md.unifyFloat64(data, rv) } - return e("Unsupported type '%s'.", rv.Kind()) + return e("unsupported type %s", rv.Kind()) } func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { tmap, ok := mapping.(map[string]interface{}) if !ok { - return mismatch(rv, "map", mapping) + if mapping == nil { + return nil + } + return e("type mismatch for %s: expected table but found %T", + rv.Type().String(), mapping) } for key, datum := range tmap { @@ -250,14 +263,13 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { md.decoded[md.context.add(key).String()] = true md.context = append(md.context, key) if err := md.unify(datum, subv); err != nil { - return e("Type mismatch for '%s.%s': %s", - rv.Type().String(), f.name, err) + return err } md.context = md.context[0 : len(md.context)-1] } else if f.name != "" { // Bad user! No soup for you! - return e("Field '%s.%s' is unexported, and therefore cannot "+ - "be loaded with reflection.", rv.Type().String(), f.name) + return e("cannot write unexported field %s.%s", + rv.Type().String(), f.name) } } } @@ -267,6 +279,9 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { tmap, ok := mapping.(map[string]interface{}) if !ok { + if tmap == nil { + return nil + } return badtype("map", mapping) } if rv.IsNil() { @@ -292,6 +307,9 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { datav := reflect.ValueOf(data) if datav.Kind() != reflect.Slice { + if !datav.IsValid() { + return nil + } return badtype("slice", data) } sliceLen := datav.Len() @@ -305,12 +323,16 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error { datav := reflect.ValueOf(data) if datav.Kind() != reflect.Slice { + if !datav.IsValid() { + return nil + } return badtype("slice", data) } - sliceLen := datav.Len() - if rv.IsNil() { - rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen)) + n := datav.Len() + if rv.IsNil() || rv.Cap() < n { + rv.Set(reflect.MakeSlice(rv.Type(), n, n)) } + rv.SetLen(n) return md.unifySliceArray(datav, rv) } @@ -365,15 +387,15 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error { // No bounds checking necessary. case reflect.Int8: if num < math.MinInt8 || num > math.MaxInt8 { - return e("Value '%d' is out of range for int8.", num) + return e("value %d is out of range for int8", num) } case reflect.Int16: if num < math.MinInt16 || num > math.MaxInt16 { - return e("Value '%d' is out of range for int16.", num) + return e("value %d is out of range for int16", num) } case reflect.Int32: if num < math.MinInt32 || num > math.MaxInt32 { - return e("Value '%d' is out of range for int32.", num) + return e("value %d is out of range for int32", num) } } rv.SetInt(num) @@ -384,15 +406,15 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error { // No bounds checking necessary. case reflect.Uint8: if num < 0 || unum > math.MaxUint8 { - return e("Value '%d' is out of range for uint8.", num) + return e("value %d is out of range for uint8", num) } case reflect.Uint16: if num < 0 || unum > math.MaxUint16 { - return e("Value '%d' is out of range for uint16.", num) + return e("value %d is out of range for uint16", num) } case reflect.Uint32: if num < 0 || unum > math.MaxUint32 { - return e("Value '%d' is out of range for uint32.", num) + return e("value %d is out of range for uint32", num) } } rv.SetUint(unum) @@ -458,7 +480,7 @@ func rvalue(v interface{}) reflect.Value { // interest to us (like encoding.TextUnmarshaler). func indirect(v reflect.Value) reflect.Value { if v.Kind() != reflect.Ptr { - if v.CanAddr() { + if v.CanSet() { pv := v.Addr() if _, ok := pv.Interface().(TextUnmarshaler); ok { return pv @@ -483,10 +505,5 @@ func isUnifiable(rv reflect.Value) bool { } func badtype(expected string, data interface{}) error { - return e("Expected %s but found '%T'.", expected, data) -} - -func mismatch(user reflect.Value, expected string, data interface{}) error { - return e("Type mismatch for %s. Expected %s but found '%T'.", - user.Type().String(), expected, data) + return e("cannot load TOML value of type %T into a Go %s", data, expected) } diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode_meta.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/decode_meta.go similarity index 99% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode_meta.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/decode_meta.go index ef6f545..b9914a6 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode_meta.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/decode_meta.go @@ -77,9 +77,8 @@ func (k Key) maybeQuoted(i int) string { } if quote { return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\"" - } else { - return k[i] } + return k[i] } func (k Key) add(piece string) Key { diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/doc.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/doc.go similarity index 94% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/doc.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/doc.go index fe26800..b371f39 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/doc.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/doc.go @@ -4,7 +4,7 @@ files via reflection. There is also support for delaying decoding with the Primitive type, and querying the set of keys in a TOML document with the MetaData type. -The specification implemented: https://github.com/mojombo/toml +The specification implemented: https://github.com/toml-lang/toml The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify whether a file is a valid TOML document. It can also be used to print the diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encode.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/encode.go similarity index 85% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encode.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/encode.go index c7e227c..d905c21 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encode.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/encode.go @@ -16,17 +16,17 @@ type tomlEncodeError struct{ error } var ( errArrayMixedElementTypes = errors.New( - "can't encode array with mixed element types") + "toml: cannot encode array with mixed element types") errArrayNilElement = errors.New( - "can't encode array with nil element") + "toml: cannot encode array with nil element") errNonString = errors.New( - "can't encode a map with non-string key type") + "toml: cannot encode a map with non-string key type") errAnonNonStruct = errors.New( - "can't encode an anonymous field that is not a struct") + "toml: cannot encode an anonymous field that is not a struct") errArrayNoTable = errors.New( - "TOML array element can't contain a table") + "toml: TOML array element cannot contain a table") errNoKey = errors.New( - "top-level values must be a Go map or struct") + "toml: top-level values must be Go maps or structs") errAnything = errors.New("") // used in testing ) @@ -148,7 +148,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) { case reflect.Struct: enc.eTable(key, rv) default: - panic(e("Unsupported type for key '%s': %s", key, k)) + panic(e("unsupported type for key '%s': %s", key, k)) } } @@ -160,7 +160,7 @@ func (enc *Encoder) eElement(rv reflect.Value) { // Special case time.Time as a primitive. Has to come before // TextMarshaler below because time.Time implements // encoding.TextMarshaler, but we need to always use UTC. - enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z")) + enc.wf(v.UTC().Format("2006-01-02T15:04:05Z")) return case TextMarshaler: // Special case. Use text marshaler if it's available for this value. @@ -191,7 +191,7 @@ func (enc *Encoder) eElement(rv reflect.Value) { case reflect.String: enc.writeQuoted(rv.String()) default: - panic(e("Unexpected primitive type: %s", rv.Kind())) + panic(e("unexpected primitive type: %s", rv.Kind())) } } @@ -241,7 +241,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { func (enc *Encoder) eTable(key Key, rv reflect.Value) { panicIfInvalidKey(key) if len(key) == 1 { - // Output an extra new line between top-level tables. + // Output an extra newline between top-level tables. // (The newline isn't written if nothing else has been written though.) enc.newline() } @@ -306,19 +306,36 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) { addFields = func(rt reflect.Type, rv reflect.Value, start []int) { for i := 0; i < rt.NumField(); i++ { f := rt.Field(i) - // skip unexporded fields - if f.PkgPath != "" { + // skip unexported fields + if f.PkgPath != "" && !f.Anonymous { continue } frv := rv.Field(i) if f.Anonymous { - frv := eindirect(frv) - t := frv.Type() - if t.Kind() != reflect.Struct { - encPanic(errAnonNonStruct) + t := f.Type + switch t.Kind() { + case reflect.Struct: + // Treat anonymous struct fields with + // tag names as though they are not + // anonymous, like encoding/json does. + if getOptions(f.Tag).name == "" { + addFields(t, frv, f.Index) + continue + } + case reflect.Ptr: + if t.Elem().Kind() == reflect.Struct && + getOptions(f.Tag).name == "" { + if !frv.IsNil() { + addFields(t.Elem(), frv.Elem(), f.Index) + } + continue + } + // Fall through to the normal field encoding logic below + // for non-struct anonymous fields. } - addFields(t, frv, f.Index) - } else if typeIsHash(tomlTypeOfGo(frv)) { + } + + if typeIsHash(tomlTypeOfGo(frv)) { fieldsSub = append(fieldsSub, append(start, f.Index...)) } else { fieldsDirect = append(fieldsDirect, append(start, f.Index...)) @@ -336,18 +353,18 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) { continue } - keyName := sft.Tag.Get("toml") - if keyName == "-" { + opts := getOptions(sft.Tag) + if opts.skip { continue } - if keyName == "" { - keyName = sft.Name + keyName := sft.Name + if opts.name != "" { + keyName = opts.name } - - keyName, opts := getOptions(keyName) - if _, ok := opts["omitempty"]; ok && isEmpty(sf) { + if opts.omitempty && isEmpty(sf) { continue - } else if _, ok := opts["omitzero"]; ok && isZero(sf) { + } + if opts.omitzero && isZero(sf) { continue } @@ -382,9 +399,8 @@ func tomlTypeOfGo(rv reflect.Value) tomlType { case reflect.Array, reflect.Slice: if typeEqual(tomlHash, tomlArrayType(rv)) { return tomlArrayHash - } else { - return tomlArray } + return tomlArray case reflect.Ptr, reflect.Interface: return tomlTypeOfGo(rv.Elem()) case reflect.String: @@ -441,50 +457,51 @@ func tomlArrayType(rv reflect.Value) tomlType { return firstType } -func getOptions(keyName string) (string, map[string]struct{}) { - opts := make(map[string]struct{}) - ss := strings.Split(keyName, ",") - name := ss[0] - if len(ss) > 1 { - for _, opt := range ss { - opts[opt] = struct{}{} +type tagOptions struct { + skip bool // "-" + name string + omitempty bool + omitzero bool +} + +func getOptions(tag reflect.StructTag) tagOptions { + t := tag.Get("toml") + if t == "-" { + return tagOptions{skip: true} + } + var opts tagOptions + parts := strings.Split(t, ",") + opts.name = parts[0] + for _, s := range parts[1:] { + switch s { + case "omitempty": + opts.omitempty = true + case "omitzero": + opts.omitzero = true } } - - return name, opts + return opts } func isZero(rv reflect.Value) bool { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if rv.Int() == 0 { - return true - } + return rv.Int() == 0 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if rv.Uint() == 0 { - return true - } + return rv.Uint() == 0 case reflect.Float32, reflect.Float64: - if rv.Float() == 0.0 { - return true - } + return rv.Float() == 0.0 } - return false } func isEmpty(rv reflect.Value) bool { switch rv.Kind() { - case reflect.String: - if len(strings.TrimSpace(rv.String())) == 0 { - return true - } - case reflect.Array, reflect.Slice, reflect.Map: - if rv.Len() == 0 { - return true - } + case reflect.Array, reflect.Slice, reflect.Map, reflect.String: + return rv.Len() == 0 + case reflect.Bool: + return !rv.Bool() } - return false } diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/encoding_types.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/encoding_types.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/lex.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/lex.go similarity index 62% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/lex.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/lex.go index 2191228..6dee7fc 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/lex.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/lex.go @@ -3,6 +3,7 @@ package toml import ( "fmt" "strings" + "unicode" "unicode/utf8" ) @@ -29,24 +30,28 @@ const ( itemArrayTableEnd itemKeyStart itemCommentStart + itemInlineTableStart + itemInlineTableEnd ) const ( - eof = 0 - tableStart = '[' - tableEnd = ']' - arrayTableStart = '[' - arrayTableEnd = ']' - tableSep = '.' - keySep = '=' - arrayStart = '[' - arrayEnd = ']' - arrayValTerm = ',' - commentStart = '#' - stringStart = '"' - stringEnd = '"' - rawStringStart = '\'' - rawStringEnd = '\'' + eof = 0 + comma = ',' + tableStart = '[' + tableEnd = ']' + arrayTableStart = '[' + arrayTableEnd = ']' + tableSep = '.' + keySep = '=' + arrayStart = '[' + arrayEnd = ']' + commentStart = '#' + stringStart = '"' + stringEnd = '"' + rawStringStart = '\'' + rawStringEnd = '\'' + inlineTableStart = '{' + inlineTableEnd = '}' ) type stateFn func(lx *lexer) stateFn @@ -55,11 +60,18 @@ type lexer struct { input string start int pos int - width int line int state stateFn items chan item + // Allow for backing up up to three runes. + // This is necessary because TOML contains 3-rune tokens (""" and '''). + prevWidths [3]int + nprev int // how many of prevWidths are in use + // If we emit an eof, we can still back up, but it is not OK to call + // next again. + atEOF bool + // A stack of state functions used to maintain context. // The idea is to reuse parts of the state machine in various places. // For example, values can appear at the top level or within arbitrarily @@ -87,7 +99,7 @@ func (lx *lexer) nextItem() item { func lex(input string) *lexer { lx := &lexer{ - input: input + "\n", + input: input, state: lexTop, line: 1, items: make(chan item, 10), @@ -102,7 +114,7 @@ func (lx *lexer) push(state stateFn) { func (lx *lexer) pop() stateFn { if len(lx.stack) == 0 { - return lx.errorf("BUG in lexer: no states to pop.") + return lx.errorf("BUG in lexer: no states to pop") } last := lx.stack[len(lx.stack)-1] lx.stack = lx.stack[0 : len(lx.stack)-1] @@ -124,16 +136,25 @@ func (lx *lexer) emitTrim(typ itemType) { } func (lx *lexer) next() (r rune) { + if lx.atEOF { + panic("next called after EOF") + } if lx.pos >= len(lx.input) { - lx.width = 0 + lx.atEOF = true return eof } if lx.input[lx.pos] == '\n' { lx.line++ } - r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:]) - lx.pos += lx.width + lx.prevWidths[2] = lx.prevWidths[1] + lx.prevWidths[1] = lx.prevWidths[0] + if lx.nprev < 3 { + lx.nprev++ + } + r, w := utf8.DecodeRuneInString(lx.input[lx.pos:]) + lx.prevWidths[0] = w + lx.pos += w return r } @@ -142,9 +163,20 @@ func (lx *lexer) ignore() { lx.start = lx.pos } -// backup steps back one rune. Can be called only once per call of next. +// backup steps back one rune. Can be called only twice between calls to next. func (lx *lexer) backup() { - lx.pos -= lx.width + if lx.atEOF { + lx.atEOF = false + return + } + if lx.nprev < 1 { + panic("backed up too far") + } + w := lx.prevWidths[0] + lx.prevWidths[0] = lx.prevWidths[1] + lx.prevWidths[1] = lx.prevWidths[2] + lx.nprev-- + lx.pos -= w if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' { lx.line-- } @@ -166,9 +198,22 @@ func (lx *lexer) peek() rune { return r } +// skip ignores all input that matches the given predicate. +func (lx *lexer) skip(pred func(rune) bool) { + for { + r := lx.next() + if pred(r) { + continue + } + lx.backup() + lx.ignore() + return + } +} + // errorf stops all lexing by emitting an error and returning `nil`. // Note that any value that is a character is escaped if it's a special -// character (new lines, tabs, etc.). +// character (newlines, tabs, etc.). func (lx *lexer) errorf(format string, values ...interface{}) stateFn { lx.items <- item{ itemError, @@ -184,7 +229,6 @@ func lexTop(lx *lexer) stateFn { if isWhitespace(r) || isNL(r) { return lexSkip(lx, lexTop) } - switch r { case commentStart: lx.push(lexTop) @@ -193,7 +237,7 @@ func lexTop(lx *lexer) stateFn { return lexTableStart case eof: if lx.pos > lx.start { - return lx.errorf("Unexpected EOF.") + return lx.errorf("unexpected EOF") } lx.emit(itemEOF) return nil @@ -208,12 +252,12 @@ func lexTop(lx *lexer) stateFn { // lexTopEnd is entered whenever a top-level item has been consumed. (A value // or a table.) It must see only whitespace, and will turn back to lexTop -// upon a new line. If it sees EOF, it will quit the lexer successfully. +// upon a newline. If it sees EOF, it will quit the lexer successfully. func lexTopEnd(lx *lexer) stateFn { r := lx.next() switch { case r == commentStart: - // a comment will read to a new line for us. + // a comment will read to a newline for us. lx.push(lexTop) return lexCommentStart case isWhitespace(r): @@ -222,11 +266,11 @@ func lexTopEnd(lx *lexer) stateFn { lx.ignore() return lexTop case r == eof: - lx.ignore() - return lexTop + lx.emit(itemEOF) + return nil } - return lx.errorf("Expected a top-level item to end with a new line, "+ - "comment or EOF, but got %q instead.", r) + return lx.errorf("expected a top-level item to end with a newline, "+ + "comment, or EOF, but got %q instead", r) } // lexTable lexes the beginning of a table. Namely, it makes sure that @@ -253,50 +297,47 @@ func lexTableEnd(lx *lexer) stateFn { func lexArrayTableEnd(lx *lexer) stateFn { if r := lx.next(); r != arrayTableEnd { - return lx.errorf("Expected end of table array name delimiter %q, "+ - "but got %q instead.", arrayTableEnd, r) + return lx.errorf("expected end of table array name delimiter %q, "+ + "but got %q instead", arrayTableEnd, r) } lx.emit(itemArrayTableEnd) return lexTopEnd } func lexTableNameStart(lx *lexer) stateFn { + lx.skip(isWhitespace) switch r := lx.peek(); { case r == tableEnd || r == eof: - return lx.errorf("Unexpected end of table name. (Table names cannot " + - "be empty.)") + return lx.errorf("unexpected end of table name " + + "(table names cannot be empty)") case r == tableSep: - return lx.errorf("Unexpected table separator. (Table names cannot " + - "be empty.)") + return lx.errorf("unexpected table separator " + + "(table names cannot be empty)") case r == stringStart || r == rawStringStart: lx.ignore() lx.push(lexTableNameEnd) return lexValue // reuse string lexing - case isWhitespace(r): - return lexTableNameStart default: return lexBareTableName } } -// lexTableName lexes the name of a table. It assumes that at least one +// lexBareTableName lexes the name of a table. It assumes that at least one // valid character for the table has already been read. func lexBareTableName(lx *lexer) stateFn { - switch r := lx.next(); { - case isBareKeyChar(r): + r := lx.next() + if isBareKeyChar(r) { return lexBareTableName - case r == tableSep || r == tableEnd: - lx.backup() - lx.emitTrim(itemText) - return lexTableNameEnd - default: - return lx.errorf("Bare keys cannot contain %q.", r) } + lx.backup() + lx.emit(itemText) + return lexTableNameEnd } // lexTableNameEnd reads the end of a piece of a table name, optionally // consuming whitespace. func lexTableNameEnd(lx *lexer) stateFn { + lx.skip(isWhitespace) switch r := lx.next(); { case isWhitespace(r): return lexTableNameEnd @@ -306,8 +347,8 @@ func lexTableNameEnd(lx *lexer) stateFn { case r == tableEnd: return lx.pop() default: - return lx.errorf("Expected '.' or ']' to end table name, but got %q "+ - "instead.", r) + return lx.errorf("expected '.' or ']' to end table name, "+ + "but got %q instead", r) } } @@ -317,7 +358,7 @@ func lexKeyStart(lx *lexer) stateFn { r := lx.peek() switch { case r == keySep: - return lx.errorf("Unexpected key separator %q.", keySep) + return lx.errorf("unexpected key separator %q", keySep) case isWhitespace(r) || isNL(r): lx.next() return lexSkip(lx, lexKeyStart) @@ -340,14 +381,15 @@ func lexBareKey(lx *lexer) stateFn { case isBareKeyChar(r): return lexBareKey case isWhitespace(r): - lx.emitTrim(itemText) + lx.backup() + lx.emit(itemText) return lexKeyEnd case r == keySep: lx.backup() - lx.emitTrim(itemText) + lx.emit(itemText) return lexKeyEnd default: - return lx.errorf("Bare keys cannot contain %q.", r) + return lx.errorf("bare keys cannot contain %q", r) } } @@ -360,7 +402,7 @@ func lexKeyEnd(lx *lexer) stateFn { case isWhitespace(r): return lexSkip(lx, lexKeyEnd) default: - return lx.errorf("Expected key separator %q, but got %q instead.", + return lx.errorf("expected key separator %q, but got %q instead", keySep, r) } } @@ -369,20 +411,26 @@ func lexKeyEnd(lx *lexer) stateFn { // lexValue will ignore whitespace. // After a value is lexed, the last state on the next is popped and returned. func lexValue(lx *lexer) stateFn { - // We allow whitespace to precede a value, but NOT new lines. - // In array syntax, the array states are responsible for ignoring new - // lines. + // We allow whitespace to precede a value, but NOT newlines. + // In array syntax, the array states are responsible for ignoring newlines. r := lx.next() - if isWhitespace(r) { - return lexSkip(lx, lexValue) - } - switch { - case r == arrayStart: + case isWhitespace(r): + return lexSkip(lx, lexValue) + case isDigit(r): + lx.backup() // avoid an extra state and use the same as above + return lexNumberOrDateStart + } + switch r { + case arrayStart: lx.ignore() lx.emit(itemArray) return lexArrayValue - case r == stringStart: + case inlineTableStart: + lx.ignore() + lx.emit(itemInlineTableStart) + return lexInlineTableValue + case stringStart: if lx.accept(stringStart) { if lx.accept(stringStart) { lx.ignore() // Ignore """ @@ -392,7 +440,7 @@ func lexValue(lx *lexer) stateFn { } lx.ignore() // ignore the '"' return lexString - case r == rawStringStart: + case rawStringStart: if lx.accept(rawStringStart) { if lx.accept(rawStringStart) { lx.ignore() // Ignore """ @@ -402,23 +450,24 @@ func lexValue(lx *lexer) stateFn { } lx.ignore() // ignore the "'" return lexRawString - case r == 't': - return lexTrue - case r == 'f': - return lexFalse - case r == '-': + case '+', '-': return lexNumberStart - case isDigit(r): - lx.backup() // avoid an extra state and use the same as above - return lexNumberOrDateStart - case r == '.': // special error case, be kind to users - return lx.errorf("Floats must start with a digit, not '.'.") + case '.': // special error case, be kind to users + return lx.errorf("floats must start with a digit, not '.'") } - return lx.errorf("Expected value but found %q instead.", r) + if unicode.IsLetter(r) { + // Be permissive here; lexBool will give a nice error if the + // user wrote something like + // x = foo + // (i.e. not 'true' or 'false' but is something else word-like.) + lx.backup() + return lexBool + } + return lx.errorf("expected value but found %q instead", r) } // lexArrayValue consumes one value in an array. It assumes that '[' or ',' -// have already been consumed. All whitespace and new lines are ignored. +// have already been consumed. All whitespace and newlines are ignored. func lexArrayValue(lx *lexer) stateFn { r := lx.next() switch { @@ -427,10 +476,11 @@ func lexArrayValue(lx *lexer) stateFn { case r == commentStart: lx.push(lexArrayValue) return lexCommentStart - case r == arrayValTerm: - return lx.errorf("Unexpected array value terminator %q.", - arrayValTerm) + case r == comma: + return lx.errorf("unexpected comma") case r == arrayEnd: + // NOTE(caleb): The spec isn't clear about whether you can have + // a trailing comma or not, so we'll allow it. return lexArrayEnd } @@ -439,8 +489,9 @@ func lexArrayValue(lx *lexer) stateFn { return lexValue } -// lexArrayValueEnd consumes the cruft between values of an array. Namely, -// it ignores whitespace and expects either a ',' or a ']'. +// lexArrayValueEnd consumes everything between the end of an array value and +// the next value (or the end of the array): it ignores whitespace and newlines +// and expects either a ',' or a ']'. func lexArrayValueEnd(lx *lexer) stateFn { r := lx.next() switch { @@ -449,31 +500,88 @@ func lexArrayValueEnd(lx *lexer) stateFn { case r == commentStart: lx.push(lexArrayValueEnd) return lexCommentStart - case r == arrayValTerm: + case r == comma: lx.ignore() return lexArrayValue // move on to the next value case r == arrayEnd: return lexArrayEnd } - return lx.errorf("Expected an array value terminator %q or an array "+ - "terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r) + return lx.errorf( + "expected a comma or array terminator %q, but got %q instead", + arrayEnd, r, + ) } -// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has -// just been consumed. +// lexArrayEnd finishes the lexing of an array. +// It assumes that a ']' has just been consumed. func lexArrayEnd(lx *lexer) stateFn { lx.ignore() lx.emit(itemArrayEnd) return lx.pop() } +// lexInlineTableValue consumes one key/value pair in an inline table. +// It assumes that '{' or ',' have already been consumed. Whitespace is ignored. +func lexInlineTableValue(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexInlineTableValue) + case isNL(r): + return lx.errorf("newlines not allowed within inline tables") + case r == commentStart: + lx.push(lexInlineTableValue) + return lexCommentStart + case r == comma: + return lx.errorf("unexpected comma") + case r == inlineTableEnd: + return lexInlineTableEnd + } + lx.backup() + lx.push(lexInlineTableValueEnd) + return lexKeyStart +} + +// lexInlineTableValueEnd consumes everything between the end of an inline table +// key/value pair and the next pair (or the end of the table): +// it ignores whitespace and expects either a ',' or a '}'. +func lexInlineTableValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexInlineTableValueEnd) + case isNL(r): + return lx.errorf("newlines not allowed within inline tables") + case r == commentStart: + lx.push(lexInlineTableValueEnd) + return lexCommentStart + case r == comma: + lx.ignore() + return lexInlineTableValue + case r == inlineTableEnd: + return lexInlineTableEnd + } + return lx.errorf("expected a comma or an inline table terminator %q, "+ + "but got %q instead", inlineTableEnd, r) +} + +// lexInlineTableEnd finishes the lexing of an inline table. +// It assumes that a '}' has just been consumed. +func lexInlineTableEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemInlineTableEnd) + return lx.pop() +} + // lexString consumes the inner contents of a string. It assumes that the // beginning '"' has already been consumed and ignored. func lexString(lx *lexer) stateFn { r := lx.next() switch { + case r == eof: + return lx.errorf("unexpected EOF") case isNL(r): - return lx.errorf("Strings cannot contain new lines.") + return lx.errorf("strings cannot contain newlines") case r == '\\': lx.push(lexString) return lexStringEscape @@ -490,11 +598,12 @@ func lexString(lx *lexer) stateFn { // lexMultilineString consumes the inner contents of a string. It assumes that // the beginning '"""' has already been consumed and ignored. func lexMultilineString(lx *lexer) stateFn { - r := lx.next() - switch { - case r == '\\': + switch lx.next() { + case eof: + return lx.errorf("unexpected EOF") + case '\\': return lexMultilineStringEscape - case r == stringEnd: + case stringEnd: if lx.accept(stringEnd) { if lx.accept(stringEnd) { lx.backup() @@ -518,8 +627,10 @@ func lexMultilineString(lx *lexer) stateFn { func lexRawString(lx *lexer) stateFn { r := lx.next() switch { + case r == eof: + return lx.errorf("unexpected EOF") case isNL(r): - return lx.errorf("Strings cannot contain new lines.") + return lx.errorf("strings cannot contain newlines") case r == rawStringEnd: lx.backup() lx.emit(itemRawString) @@ -531,12 +642,13 @@ func lexRawString(lx *lexer) stateFn { } // lexMultilineRawString consumes a raw string. Nothing can be escaped in such -// a string. It assumes that the beginning "'" has already been consumed and +// a string. It assumes that the beginning "'''" has already been consumed and // ignored. func lexMultilineRawString(lx *lexer) stateFn { - r := lx.next() - switch { - case r == rawStringEnd: + switch lx.next() { + case eof: + return lx.errorf("unexpected EOF") + case rawStringEnd: if lx.accept(rawStringEnd) { if lx.accept(rawStringEnd) { lx.backup() @@ -560,13 +672,11 @@ func lexMultilineRawString(lx *lexer) stateFn { func lexMultilineStringEscape(lx *lexer) stateFn { // Handle the special case first: if isNL(lx.next()) { - lx.next() return lexMultilineString - } else { - lx.backup() - lx.push(lexMultilineString) - return lexStringEscape(lx) } + lx.backup() + lx.push(lexMultilineString) + return lexStringEscape(lx) } func lexStringEscape(lx *lexer) stateFn { @@ -591,10 +701,9 @@ func lexStringEscape(lx *lexer) stateFn { case 'U': return lexLongUnicodeEscape } - return lx.errorf("Invalid escape character %q. Only the following "+ + return lx.errorf("invalid escape character %q; only the following "+ "escape characters are allowed: "+ - "\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+ - "\\uXXXX and \\UXXXXXXXX.", r) + `\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r) } func lexShortUnicodeEscape(lx *lexer) stateFn { @@ -602,8 +711,8 @@ func lexShortUnicodeEscape(lx *lexer) stateFn { for i := 0; i < 4; i++ { r = lx.next() if !isHexadecimal(r) { - return lx.errorf("Expected four hexadecimal digits after '\\u', "+ - "but got '%s' instead.", lx.current()) + return lx.errorf(`expected four hexadecimal digits after '\u', `+ + "but got %q instead", lx.current()) } } return lx.pop() @@ -614,40 +723,43 @@ func lexLongUnicodeEscape(lx *lexer) stateFn { for i := 0; i < 8; i++ { r = lx.next() if !isHexadecimal(r) { - return lx.errorf("Expected eight hexadecimal digits after '\\U', "+ - "but got '%s' instead.", lx.current()) + return lx.errorf(`expected eight hexadecimal digits after '\U', `+ + "but got %q instead", lx.current()) } } return lx.pop() } -// lexNumberOrDateStart consumes either a (positive) integer, float or -// datetime. It assumes that NO negative sign has been consumed. +// lexNumberOrDateStart consumes either an integer, a float, or datetime. func lexNumberOrDateStart(lx *lexer) stateFn { r := lx.next() - if !isDigit(r) { - if r == '.' { - return lx.errorf("Floats must start with a digit, not '.'.") - } else { - return lx.errorf("Expected a digit but got %q.", r) - } + if isDigit(r) { + return lexNumberOrDate } - return lexNumberOrDate + switch r { + case '_': + return lexNumber + case 'e', 'E': + return lexFloat + case '.': + return lx.errorf("floats must start with a digit, not '.'") + } + return lx.errorf("expected a digit but got %q", r) } -// lexNumberOrDate consumes either a (positive) integer, float or datetime. +// lexNumberOrDate consumes either an integer, float or datetime. func lexNumberOrDate(lx *lexer) stateFn { r := lx.next() - switch { - case r == '-': - if lx.pos-lx.start != 5 { - return lx.errorf("All ISO8601 dates must be in full Zulu form.") - } - return lexDateAfterYear - case isDigit(r): + if isDigit(r) { return lexNumberOrDate - case r == '.': - return lexFloatStart + } + switch r { + case '-': + return lexDatetime + case '_': + return lexNumber + case '.', 'e', 'E': + return lexFloat } lx.backup() @@ -655,46 +767,34 @@ func lexNumberOrDate(lx *lexer) stateFn { return lx.pop() } -// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format. -// It assumes that "YYYY-" has already been consumed. -func lexDateAfterYear(lx *lexer) stateFn { - formats := []rune{ - // digits are '0'. - // everything else is direct equality. - '0', '0', '-', '0', '0', - 'T', - '0', '0', ':', '0', '0', ':', '0', '0', - 'Z', +// lexDatetime consumes a Datetime, to a first approximation. +// The parser validates that it matches one of the accepted formats. +func lexDatetime(lx *lexer) stateFn { + r := lx.next() + if isDigit(r) { + return lexDatetime } - for _, f := range formats { - r := lx.next() - if f == '0' { - if !isDigit(r) { - return lx.errorf("Expected digit in ISO8601 datetime, "+ - "but found %q instead.", r) - } - } else if f != r { - return lx.errorf("Expected %q in ISO8601 datetime, "+ - "but found %q instead.", f, r) - } + switch r { + case '-', 'T', ':', '.', 'Z': + return lexDatetime } + + lx.backup() lx.emit(itemDatetime) return lx.pop() } -// lexNumberStart consumes either an integer or a float. It assumes that -// a negative sign has already been read, but that *no* digits have been -// consumed. lexNumberStart will move to the appropriate integer or float -// states. +// lexNumberStart consumes either an integer or a float. It assumes that a sign +// has already been read, but that *no* digits have been consumed. +// lexNumberStart will move to the appropriate integer or float states. func lexNumberStart(lx *lexer) stateFn { - // we MUST see a digit. Even floats have to start with a digit. + // We MUST see a digit. Even floats have to start with a digit. r := lx.next() if !isDigit(r) { if r == '.' { - return lx.errorf("Floats must start with a digit, not '.'.") - } else { - return lx.errorf("Expected a digit but got %q.", r) + return lx.errorf("floats must start with a digit, not '.'") } + return lx.errorf("expected a digit but got %q", r) } return lexNumber } @@ -702,11 +802,14 @@ func lexNumberStart(lx *lexer) stateFn { // lexNumber consumes an integer or a float after seeing the first digit. func lexNumber(lx *lexer) stateFn { r := lx.next() - switch { - case isDigit(r): + if isDigit(r) { return lexNumber - case r == '.': - return lexFloatStart + } + switch r { + case '_': + return lexNumber + case '.', 'e', 'E': + return lexFloat } lx.backup() @@ -714,60 +817,42 @@ func lexNumber(lx *lexer) stateFn { return lx.pop() } -// lexFloatStart starts the consumption of digits of a float after a '.'. -// Namely, at least one digit is required. -func lexFloatStart(lx *lexer) stateFn { - r := lx.next() - if !isDigit(r) { - return lx.errorf("Floats must have a digit after the '.', but got "+ - "%q instead.", r) - } - return lexFloat -} - -// lexFloat consumes the digits of a float after a '.'. -// Assumes that one digit has been consumed after a '.' already. +// lexFloat consumes the elements of a float. It allows any sequence of +// float-like characters, so floats emitted by the lexer are only a first +// approximation and must be validated by the parser. func lexFloat(lx *lexer) stateFn { r := lx.next() if isDigit(r) { return lexFloat } + switch r { + case '_', '.', '-', '+', 'e', 'E': + return lexFloat + } lx.backup() lx.emit(itemFloat) return lx.pop() } -// lexConst consumes the s[1:] in s. It assumes that s[0] has already been -// consumed. -func lexConst(lx *lexer, s string) stateFn { - for i := range s[1:] { - if r := lx.next(); r != rune(s[i+1]) { - return lx.errorf("Expected %q, but found %q instead.", s[:i+1], - s[:i]+string(r)) +// lexBool consumes a bool string: 'true' or 'false. +func lexBool(lx *lexer) stateFn { + var rs []rune + for { + r := lx.next() + if !unicode.IsLetter(r) { + lx.backup() + break } + rs = append(rs, r) } - return nil -} - -// lexTrue consumes the "rue" in "true". It assumes that 't' has already -// been consumed. -func lexTrue(lx *lexer) stateFn { - if fn := lexConst(lx, "true"); fn != nil { - return fn + s := string(rs) + switch s { + case "true", "false": + lx.emit(itemBool) + return lx.pop() } - lx.emit(itemBool) - return lx.pop() -} - -// lexFalse consumes the "alse" in "false". It assumes that 'f' has already -// been consumed. -func lexFalse(lx *lexer) stateFn { - if fn := lexConst(lx, "false"); fn != nil { - return fn - } - lx.emit(itemBool) - return lx.pop() + return lx.errorf("expected value but found %q instead", s) } // lexCommentStart begins the lexing of a comment. It will emit @@ -779,7 +864,7 @@ func lexCommentStart(lx *lexer) stateFn { } // lexComment lexes an entire comment. It assumes that '#' has been consumed. -// It will consume *up to* the first new line character, and pass control +// It will consume *up to* the first newline character, and pass control // back to the last state on the stack. func lexComment(lx *lexer) stateFn { r := lx.peek() @@ -837,13 +922,7 @@ func (itype itemType) String() string { return "EOF" case itemText: return "Text" - case itemString: - return "String" - case itemRawString: - return "String" - case itemMultilineString: - return "String" - case itemRawMultilineString: + case itemString, itemRawString, itemMultilineString, itemRawMultilineString: return "String" case itemBool: return "Bool" diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/parse.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/parse.go similarity index 80% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/parse.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/parse.go index c6069be..50869ef 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/parse.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/parse.go @@ -2,7 +2,6 @@ package toml import ( "fmt" - "log" "strconv" "strings" "time" @@ -81,7 +80,7 @@ func (p *parser) next() item { } func (p *parser) bug(format string, v ...interface{}) { - log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...)) + panic(fmt.Sprintf("BUG: "+format+"\n\n", v...)) } func (p *parser) expect(typ itemType) item { @@ -179,10 +178,18 @@ func (p *parser) value(it item) (interface{}, tomlType) { } p.bug("Expected boolean value, but got '%s'.", it.val) case itemInteger: - num, err := strconv.ParseInt(it.val, 10, 64) + if !numUnderscoresOK(it.val) { + p.panicf("Invalid integer %q: underscores must be surrounded by digits", + it.val) + } + val := strings.Replace(it.val, "_", "", -1) + num, err := strconv.ParseInt(val, 10, 64) if err != nil { - // See comment below for floats describing why we make a - // distinction between a bug and a user error. + // Distinguish integer values. Normally, it'd be a bug if the lexer + // provides an invalid integer, but it's possible that the number is + // out of range of valid values (which the lexer cannot determine). + // So mark the former as a bug but the latter as a legitimate user + // error. if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { @@ -194,29 +201,57 @@ func (p *parser) value(it item) (interface{}, tomlType) { } return num, p.typeOfPrimitive(it) case itemFloat: - num, err := strconv.ParseFloat(it.val, 64) + parts := strings.FieldsFunc(it.val, func(r rune) bool { + switch r { + case '.', 'e', 'E': + return true + } + return false + }) + for _, part := range parts { + if !numUnderscoresOK(part) { + p.panicf("Invalid float %q: underscores must be "+ + "surrounded by digits", it.val) + } + } + if !numPeriodsOK(it.val) { + // As a special case, numbers like '123.' or '1.e2', + // which are valid as far as Go/strconv are concerned, + // must be rejected because TOML says that a fractional + // part consists of '.' followed by 1+ digits. + p.panicf("Invalid float %q: '.' must be followed "+ + "by one or more digits", it.val) + } + val := strings.Replace(it.val, "_", "", -1) + num, err := strconv.ParseFloat(val, 64) if err != nil { - // Distinguish float values. Normally, it'd be a bug if the lexer - // provides an invalid float, but it's possible that the float is - // out of range of valid values (which the lexer cannot determine). - // So mark the former as a bug but the latter as a legitimate user - // error. - // - // This is also true for integers. if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { p.panicf("Float '%s' is out of the range of 64-bit "+ "IEEE-754 floating-point numbers.", it.val) } else { - p.bug("Expected float value, but got '%s'.", it.val) + p.panicf("Invalid float value: %q", it.val) } } return num, p.typeOfPrimitive(it) case itemDatetime: - t, err := time.Parse("2006-01-02T15:04:05Z", it.val) - if err != nil { - p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val) + var t time.Time + var ok bool + var err error + for _, format := range []string{ + "2006-01-02T15:04:05Z07:00", + "2006-01-02T15:04:05", + "2006-01-02", + } { + t, err = time.ParseInLocation(format, it.val, time.Local) + if err == nil { + ok = true + break + } + } + if !ok { + p.panicf("Invalid TOML Datetime: %q.", it.val) } return t, p.typeOfPrimitive(it) case itemArray: @@ -234,11 +269,75 @@ func (p *parser) value(it item) (interface{}, tomlType) { types = append(types, typ) } return array, p.typeOfArray(types) + case itemInlineTableStart: + var ( + hash = make(map[string]interface{}) + outerContext = p.context + outerKey = p.currentKey + ) + + p.context = append(p.context, p.currentKey) + p.currentKey = "" + for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() { + if it.typ != itemKeyStart { + p.bug("Expected key start but instead found %q, around line %d", + it.val, p.approxLine) + } + if it.typ == itemCommentStart { + p.expect(itemText) + continue + } + + // retrieve key + k := p.next() + p.approxLine = k.line + kname := p.keyString(k) + + // retrieve value + p.currentKey = kname + val, typ := p.value(p.next()) + // make sure we keep metadata up to date + p.setType(kname, typ) + p.ordered = append(p.ordered, p.context.add(p.currentKey)) + hash[kname] = val + } + p.context = outerContext + p.currentKey = outerKey + return hash, tomlHash } p.bug("Unexpected value type: %s", it.typ) panic("unreachable") } +// numUnderscoresOK checks whether each underscore in s is surrounded by +// characters that are not underscores. +func numUnderscoresOK(s string) bool { + accept := false + for _, r := range s { + if r == '_' { + if !accept { + return false + } + accept = false + continue + } + accept = true + } + return accept +} + +// numPeriodsOK checks whether every period in s is followed by a digit. +func numPeriodsOK(s string) bool { + period := false + for _, r := range s { + if period && !isDigit(r) { + return false + } + period = r == '.' + } + return !period +} + // establishContext sets the current context of the parser, // where the context is either a hash or an array of hashes. Which one is // set depends on the value of the `array` parameter. @@ -401,7 +500,7 @@ func stripFirstNewline(s string) string { if len(s) == 0 || s[0] != '\n' { return s } - return s[1:len(s)] + return s[1:] } func stripEscapedWhitespace(s string) string { @@ -481,12 +580,7 @@ func (p *parser) asciiEscapeToUnicode(bs []byte) rune { p.bug("Could not parse '%s' as a hexadecimal number, but the "+ "lexer claims it's OK: %s", s, err) } - - // BUG(burntsushi) - // I honestly don't understand how this works. I can't seem - // to find a way to make this fail. I figured this would fail on invalid - // UTF-8 characters like U+DCFF, but it doesn't. - if !utf8.ValidString(string(rune(hex))) { + if !utf8.ValidRune(rune(hex)) { p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s) } return rune(hex) diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_check.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/type_check.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_check.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/type_check.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_fields.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/type_fields.go similarity index 96% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_fields.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/type_fields.go index 7592f87..608997c 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_fields.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/BurntSushi/toml/type_fields.go @@ -92,11 +92,11 @@ func typeFields(t reflect.Type) []field { // Scan f.typ for fields to include. for i := 0; i < f.typ.NumField(); i++ { sf := f.typ.Field(i) - if sf.PkgPath != "" { // unexported + if sf.PkgPath != "" && !sf.Anonymous { // unexported continue } - name := sf.Tag.Get("toml") - if name == "-" { + opts := getOptions(sf.Tag) + if opts.skip { continue } index := make([]int, len(f.index)+1) @@ -110,8 +110,9 @@ func typeFields(t reflect.Type) []field { } // Record found field and index sequence. - if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { - tagged := name != "" + if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { + tagged := opts.name != "" + name := opts.name if name == "" { name = sf.Name } diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.github/ISSUE_TEMPLATE.md b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..d9771f1 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,21 @@ +### Issue description +Tell us what should happen and what happens instead + +### Example code +```go +If possible, please enter some example code here to reproduce the issue. +``` + +### Error log +``` +If you have an error log, please paste it here. +``` + +### Configuration +*Driver version (or git SHA):* + +*Go version:* run `go version` in your console + +*Server version:* E.g. MySQL 5.6, MariaDB 10.0.20 + +*Server OS:* E.g. Debian 8.1 (Jessie), Windows 10 diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.github/PULL_REQUEST_TEMPLATE.md b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..6f5c7eb --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,9 @@ +### Description +Please explain the changes you made here. + +### Checklist +- [ ] Code compiles correctly +- [ ] Created tests which fail without the change (if possible) +- [ ] All tests passing +- [ ] Extended the README / documentation, if necessary +- [ ] Added myself / the copyright holder to the AUTHORS file diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.gitignore b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.gitignore new file mode 100644 index 0000000..2de28da --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +Icon? +ehthumbs.db +Thumbs.db +.idea diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis.yml b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis.yml new file mode 100644 index 0000000..eae311b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis.yml @@ -0,0 +1,128 @@ +sudo: false +language: go +go: + - 1.9.x + - 1.10.x + - 1.11.x + - 1.12.x + - master + +before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + +before_script: + - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf + - sudo service mysql restart + - .travis/wait_mysql.sh + - mysql -e 'create database gotest;' + +matrix: + include: + - env: DB=MYSQL8 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mysql:8.0 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MYSQL57 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mysql:5.7 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MARIA55 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mariadb:5.5 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MARIA10_1 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mariadb:10.1 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - os: osx + osx_image: xcode10.1 + addons: + homebrew: + packages: + - mysql + go: 1.12.x + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + before_script: + - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB\nlocal_infile=1" >> /usr/local/etc/my.cnf + - mysql.server start + - mysql -uroot -e 'CREATE USER gotest IDENTIFIED BY "secret"' + - mysql -uroot -e 'GRANT ALL ON *.* TO gotest' + - mysql -uroot -e 'create database gotest;' + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3306 + - export MYSQL_TEST_CONCURRENT=1 + +script: + - go test -v -covermode=count -coverprofile=coverage.out + - go vet ./... + - .travis/gofmt.sh +after_script: + - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/docker.cnf b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/docker.cnf new file mode 100644 index 0000000..e57754e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/docker.cnf @@ -0,0 +1,5 @@ +[client] +user = gotest +password = secret +host = 127.0.0.1 +port = 3307 diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/gofmt.sh b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/gofmt.sh new file mode 100755 index 0000000..9bf0d16 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/gofmt.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -ev + +# Only check for go1.10+ since the gofmt style changed +if [[ $(go version) =~ go1\.([0-9]+) ]] && ((${BASH_REMATCH[1]} >= 10)); then + test -z "$(gofmt -d -s . | tee /dev/stderr)" +fi diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/wait_mysql.sh b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/wait_mysql.sh new file mode 100755 index 0000000..e87993e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/.travis/wait_mysql.sh @@ -0,0 +1,8 @@ +#!/bin/sh +while : +do + if mysql -e 'select version()' 2>&1 | grep 'version()\|ERROR 2059 (HY000):'; then + break + fi + sleep 3 +done diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/AUTHORS new file mode 100644 index 0000000..bfe74c4 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -0,0 +1,101 @@ +# This is the official list of Go-MySQL-Driver authors for copyright purposes. + +# If you are submitting a patch, please add your name or the name of the +# organization which holds the copyright to this list in alphabetical order. + +# Names should be added to this file as +# Name +# The email address is not required for organizations. +# Please keep the list sorted. + + +# Individual Persons + +Aaron Hopkins +Achille Roussel +Alexey Palazhchenko +Andrew Reid +Arne Hormann +Asta Xie +Bulat Gaifullin +Carlos Nieto +Chris Moos +Craig Wilson +Daniel Montoya +Daniel Nichter +Daniël van Eeden +Dave Protasowski +DisposaBoy +Egor Smolyakov +Erwan Martin +Evan Shaw +Frederick Mayle +Gustavo Kristic +Hajime Nakagami +Hanno Braun +Henri Yandell +Hirotaka Yamamoto +Huyiguang +ICHINOSE Shogo +Ilia Cimpoes +INADA Naoki +Jacek Szwec +James Harr +Jeff Hodges +Jeffrey Charles +Jerome Meyer +Jian Zhen +Joshua Prunier +Julien Lefevre +Julien Schmidt +Justin Li +Justin Nuß +Kamil Dziedzic +Kevin Malachowski +Kieron Woodhouse +Lennart Rudolph +Leonardo YongUk Kim +Linh Tran Tuan +Lion Yang +Luca Looz +Lucas Liu +Luke Scott +Maciej Zimnoch +Michael Woolnough +Nicola Peduzzi +Olivier Mengué +oscarzhao +Paul Bonser +Peter Schultz +Rebecca Chin +Reed Allman +Richard Wilkes +Robert Russell +Runrioter Wung +Shuode Li +Simon J Mudd +Soroush Pour +Stan Putrya +Stanley Gunawan +Steven Hartland +Thomas Wodarek +Tim Ruffles +Tom Jenkinson +Xiangyu Hu +Xiaobing Jiang +Xiuming Chen +Zhenye Xie + +# Organizations + +Barracuda Networks, Inc. +Counting Ltd. +Facebook Inc. +GitHub Inc. +Google Inc. +InfoSum Ltd. +Keybase Inc. +Multiplay Ltd. +Percona LLC +Pivotal Inc. +Stripe Inc. diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md new file mode 100644 index 0000000..2d87d74 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md @@ -0,0 +1,167 @@ +## Version 1.4 (2018-06-03) + +Changes: + + - Documentation fixes (#530, #535, #567) + - Refactoring (#575, #579, #580, #581, #603, #615, #704) + - Cache column names (#444) + - Sort the DSN parameters in DSNs generated from a config (#637) + - Allow native password authentication by default (#644) + - Use the default port if it is missing in the DSN (#668) + - Removed the `strict` mode (#676) + - Do not query `max_allowed_packet` by default (#680) + - Dropped support Go 1.6 and lower (#696) + - Updated `ConvertValue()` to match the database/sql/driver implementation (#760) + - Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783) + - Improved the compatibility of the authentication system (#807) + +New Features: + + - Multi-Results support (#537) + - `rejectReadOnly` DSN option (#604) + - `context.Context` support (#608, #612, #627, #761) + - Transaction isolation level support (#619, #744) + - Read-Only transactions support (#618, #634) + - `NewConfig` function which initializes a config with default values (#679) + - Implemented the `ColumnType` interfaces (#667, #724) + - Support for custom string types in `ConvertValue` (#623) + - Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710) + - `caching_sha2_password` authentication plugin support (#794, #800, #801, #802) + - Implemented `driver.SessionResetter` (#779) + - `sha256_password` authentication plugin support (#808) + +Bugfixes: + + - Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718) + - Fixed LOAD LOCAL DATA INFILE for empty files (#590) + - Removed columns definition cache since it sometimes cached invalid data (#592) + - Don't mutate registered TLS configs (#600) + - Make RegisterTLSConfig concurrency-safe (#613) + - Handle missing auth data in the handshake packet correctly (#646) + - Do not retry queries when data was written to avoid data corruption (#302, #736) + - Cache the connection pointer for error handling before invalidating it (#678) + - Fixed imports for appengine/cloudsql (#700) + - Fix sending STMT_LONG_DATA for 0 byte data (#734) + - Set correct capacity for []bytes read from length-encoded strings (#766) + - Make RegisterDial concurrency-safe (#773) + + +## Version 1.3 (2016-12-01) + +Changes: + + - Go 1.1 is no longer supported + - Use decimals fields in MySQL to format time types (#249) + - Buffer optimizations (#269) + - TLS ServerName defaults to the host (#283) + - Refactoring (#400, #410, #437) + - Adjusted documentation for second generation CloudSQL (#485) + - Documented DSN system var quoting rules (#502) + - Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512) + +New Features: + + - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) + - Support for returning table alias on Columns() (#289, #359, #382) + - Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490) + - Support for uint64 parameters with high bit set (#332, #345) + - Cleartext authentication plugin support (#327) + - Exported ParseDSN function and the Config struct (#403, #419, #429) + - Read / Write timeouts (#401) + - Support for JSON field type (#414) + - Support for multi-statements and multi-results (#411, #431) + - DSN parameter to set the driver-side max_allowed_packet value manually (#489) + - Native password authentication plugin support (#494, #524) + +Bugfixes: + + - Fixed handling of queries without columns and rows (#255) + - Fixed a panic when SetKeepAlive() failed (#298) + - Handle ERR packets while reading rows (#321) + - Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349) + - Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356) + - Actually zero out bytes in handshake response (#378) + - Fixed race condition in registering LOAD DATA INFILE handler (#383) + - Fixed tests with MySQL 5.7.9+ (#380) + - QueryUnescape TLS config names (#397) + - Fixed "broken pipe" error by writing to closed socket (#390) + - Fixed LOAD LOCAL DATA INFILE buffering (#424) + - Fixed parsing of floats into float64 when placeholders are used (#434) + - Fixed DSN tests with Go 1.7+ (#459) + - Handle ERR packets while waiting for EOF (#473) + - Invalidate connection on error while discarding additional results (#513) + - Allow terminating packets of length 0 (#516) + + +## Version 1.2 (2014-06-03) + +Changes: + + - We switched back to a "rolling release". `go get` installs the current master branch again + - Version v1 of the driver will not be maintained anymore. Go 1.0 is no longer supported by this driver + - Exported errors to allow easy checking from application code + - Enabled TCP Keepalives on TCP connections + - Optimized INFILE handling (better buffer size calculation, lazy init, ...) + - The DSN parser also checks for a missing separating slash + - Faster binary date / datetime to string formatting + - Also exported the MySQLWarning type + - mysqlConn.Close returns the first error encountered instead of ignoring all errors + - writePacket() automatically writes the packet size to the header + - readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets + +New Features: + + - `RegisterDial` allows the usage of a custom dial function to establish the network connection + - Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter + - Logging of critical errors is configurable with `SetLogger` + - Google CloudSQL support + +Bugfixes: + + - Allow more than 32 parameters in prepared statements + - Various old_password fixes + - Fixed TestConcurrent test to pass Go's race detection + - Fixed appendLengthEncodedInteger for large numbers + - Renamed readLengthEnodedString to readLengthEncodedString and skipLengthEnodedString to skipLengthEncodedString (fixed typo) + + +## Version 1.1 (2013-11-02) + +Changes: + + - Go-MySQL-Driver now requires Go 1.1 + - Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore + - Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors + - `[]byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")` + - DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'. + - Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries + - Optimized the buffer for reading + - stmt.Query now caches column metadata + - New Logo + - Changed the copyright header to include all contributors + - Improved the LOAD INFILE documentation + - The driver struct is now exported to make the driver directly accessible + - Refactored the driver tests + - Added more benchmarks and moved all to a separate file + - Other small refactoring + +New Features: + + - Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure + - Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs + - Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. Custom TLS configs can be registered and used + +Bugfixes: + + - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification + - Convert to DB timezone when inserting `time.Time` + - Splitted packets (more than 16MB) are now merged correctly + - Fixed false positive `io.EOF` errors when the data was fully read + - Avoid panics on reuse of closed connections + - Fixed empty string producing false nil values + - Fixed sign byte for positive TIME fields + + +## Version 1.0 (2013-05-14) + +Initial Release diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md new file mode 100644 index 0000000..8fe16bc --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# Contributing Guidelines + +## Reporting Issues + +Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed). + +## Contributing Code + +By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file. +Don't forget to add yourself to the AUTHORS file. + +### Code Review + +Everyone is invited to review and comment on pull requests. +If it looks fine to you, comment with "LGTM" (Looks good to me). + +If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes. + +Before merging the Pull Request, at least one [team member](https://github.com/go-sql-driver?tab=members) must have commented with "LGTM". + +## Development Ideas + +If you are looking for ideas for code contributions, please check our [Development Ideas](https://github.com/go-sql-driver/mysql/wiki/Development-Ideas) Wiki page. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/LICENSE similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/LICENSE rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/LICENSE diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/README.md b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/README.md new file mode 100644 index 0000000..c6adf1d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/README.md @@ -0,0 +1,495 @@ +# Go-MySQL-Driver + +A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package + +![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin") + +--------------------------------------- + * [Features](#features) + * [Requirements](#requirements) + * [Installation](#installation) + * [Usage](#usage) + * [DSN (Data Source Name)](#dsn-data-source-name) + * [Password](#password) + * [Protocol](#protocol) + * [Address](#address) + * [Parameters](#parameters) + * [Examples](#examples) + * [Connection pool and timeouts](#connection-pool-and-timeouts) + * [context.Context Support](#contextcontext-support) + * [ColumnType Support](#columntype-support) + * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) + * [time.Time support](#timetime-support) + * [Unicode support](#unicode-support) + * [Testing / Development](#testing--development) + * [License](#license) + +--------------------------------------- + +## Features + * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance") + * Native Go implementation. No C-bindings, just pure Go + * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc) + * Automatic handling of broken connections + * Automatic Connection Pooling *(by database/sql package)* + * Supports queries larger than 16MB + * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. + * Intelligent `LONG DATA` handling in prepared statements + * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support + * Optional `time.Time` parsing + * Optional placeholder interpolation + +## Requirements + * Go 1.9 or higher. We aim to support the 3 latest versions of Go. + * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) + +--------------------------------------- + +## Installation +Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: +```bash +$ go get -u github.com/go-sql-driver/mysql +``` +Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. + +## Usage +_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. + +Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: +```go +import "database/sql" +import _ "github.com/go-sql-driver/mysql" + +db, err := sql.Open("mysql", "user:password@/dbname") +``` + +[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples"). + + +### DSN (Data Source Name) + +The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets): +``` +[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] +``` + +A DSN in its fullest form: +``` +username:password@protocol(address)/dbname?param=value +``` + +Except for the databasename, all values are optional. So the minimal DSN is: +``` +/dbname +``` + +If you do not want to preselect a database, leave `dbname` empty: +``` +/ +``` +This has the same effect as an empty DSN string: +``` + +``` + +Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. + +#### Password +Passwords can consist of any character. Escaping is **not** necessary. + +#### Protocol +See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. +In general you should use an Unix domain socket if available and TCP otherwise for best performance. + +#### Address +For TCP and UDP networks, addresses have the form `host[:port]`. +If `port` is omitted, the default port will be used. +If `host` is a literal IPv6 address, it must be enclosed in square brackets. +The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](https://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form. + +For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`. + +#### Parameters +*Parameters are case-sensitive!* + +Notice that any of `true`, `TRUE`, `True` or `1` is accepted to stand for a true boolean value. Not surprisingly, false can be specified as any of: `false`, `FALSE`, `False` or `0`. + +##### `allowAllFiles` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. +[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) + +##### `allowCleartextPasswords` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. + +##### `allowNativePasswords` + +``` +Type: bool +Valid Values: true, false +Default: true +``` +`allowNativePasswords=false` disallows the usage of MySQL native password method. + +##### `allowOldPasswords` + +``` +Type: bool +Valid Values: true, false +Default: false +``` +`allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). + +##### `charset` + +``` +Type: string +Valid Values: +Default: none +``` + +Sets the charset used for client-server interaction (`"SET NAMES "`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). + +Usage of the `charset` parameter is discouraged because it issues additional queries to the server. +Unless you need the fallback behavior, please use `collation` instead. + +##### `collation` + +``` +Type: string +Valid Values: +Default: utf8mb4_general_ci +``` + +Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail. + +A list of valid charsets for a server is retrievable with `SHOW COLLATION`. + +The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. + +Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). + + +##### `clientFoundRows` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed. + +##### `columnsWithAlias` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +When `columnsWithAlias` is true, calls to `sql.Rows.Columns()` will return the table alias and the column name separated by a dot. For example: + +``` +SELECT u.id FROM users as u +``` + +will return `u.id` instead of just `id` if `columnsWithAlias=true`. + +##### `interpolateParams` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. + +*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* + +##### `loc` + +``` +Type: string +Valid Values: +Default: UTC +``` + +Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. + +Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter. + +Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. + +##### `maxAllowedPacket` +``` +Type: decimal number +Default: 4194304 +``` + +Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. + +##### `multiStatements` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded. + +When `multiStatements` is used, `?` parameters must only be used in the first statement. + +##### `parseTime` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` +The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`. + + +##### `readTimeout` + +``` +Type: duration +Default: 0 +``` + +I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + +##### `rejectReadOnly` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + + +`rejectReadOnly=true` causes the driver to reject read-only connections. This +is for a possible race condition during an automatic failover, where the mysql +client gets connected to a read-only replica after the failover. + +Note that this should be a fairly rare case, as an automatic failover normally +happens when the primary is down, and the race condition shouldn't happen +unless it comes back up online as soon as the failover is kicked off. On the +other hand, when this happens, a MySQL application can get stuck on a +read-only connection until restarted. It is however fairly easy to reproduce, +for example, using a manual failover on AWS Aurora's MySQL-compatible cluster. + +If you are not relying on read-only transactions to reject writes that aren't +supposed to happen, setting this on some MySQL providers (such as AWS Aurora) +is safer for failovers. + +Note that ERROR 1290 can be returned for a `read-only` server and this option will +cause a retry for that error. However the same error number is used for some +other cases. You should ensure your application will never cause an ERROR 1290 +except for `read-only` mode when enabling this option. + + +##### `serverPubKey` + +``` +Type: string +Valid Values: +Default: none +``` + +Server public keys can be registered with [`mysql.RegisterServerPubKey`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterServerPubKey), which can then be used by the assigned name in the DSN. +Public keys are used to transmit encrypted data, e.g. for authentication. +If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required. + + +##### `timeout` + +``` +Type: duration +Default: OS default +``` + +Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + + +##### `tls` + +``` +Type: bool / string +Valid Values: true, false, skip-verify, preferred, +Default: false +``` + +`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). + + +##### `writeTimeout` + +``` +Type: duration +Default: 0 +``` + +I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + + +##### System Variables + +Any other parameters are interpreted as system variables: + * `=`: `SET =` + * `=`: `SET =` + * `=%27%27`: `SET =''` + +Rules: +* The values for string variables must be quoted with `'`. +* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! + (which implies values of string variables must be wrapped with `%27`). + +Examples: + * `autocommit=1`: `SET autocommit=1` + * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` + * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` + + +#### Examples +``` +user@unix(/path/to/socket)/dbname +``` + +``` +root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local +``` + +``` +user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true +``` + +Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html): +``` +user:password@/dbname?sql_mode=TRADITIONAL +``` + +TCP via IPv6: +``` +user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci +``` + +TCP on a remote host, e.g. Amazon RDS: +``` +id:password@tcp(your-amazonaws-uri.com:3306)/dbname +``` + +Google Cloud SQL on App Engine (First Generation MySQL Server): +``` +user@cloudsql(project-id:instance-name)/dbname +``` + +Google Cloud SQL on App Engine (Second Generation MySQL Server): +``` +user@cloudsql(project-id:regionname:instance-name)/dbname +``` + +TCP using default port (3306) on localhost: +``` +user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped +``` + +Use the default protocol (tcp) and host (localhost:3306): +``` +user:password@/dbname +``` + +No Database preselected: +``` +user:password@/ +``` + + +### Connection pool and timeouts +The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. + +## `ColumnType` Support +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. + +## `context.Context` Support +Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. +See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. + + +### `LOAD DATA LOCAL INFILE` support +For this feature you need direct access to the package. Therefore you must change the import path (no `_`): +```go +import "github.com/go-sql-driver/mysql" +``` + +Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). + +To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. + +See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. + + +### `time.Time` support +The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. + +However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. + +**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). + +Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. + + +### Unicode support +Since version 1.1 Go-MySQL-Driver automatically uses the collation `utf8_general_ci` by default. + +Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. + +Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default. + +See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. + +## Testing / Development +To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. + +Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. +If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls). + +See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details. + +--------------------------------------- + +## License +Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE) + +Mozilla summarizes the license scope as follows: +> MPL: The copyleft applies to any files containing MPLed code. + + +That means: + * You can **use** the **unchanged** source code both in private and commercially. + * When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). + * You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. + +Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license. + +You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). + +![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") + diff --git a/vendor/github.com/go-sql-driver/mysql/appengine.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/appengine.go similarity index 60% rename from vendor/github.com/go-sql-driver/mysql/appengine.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/appengine.go index 565614e..914e662 100644 --- a/vendor/github.com/go-sql-driver/mysql/appengine.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/appengine.go @@ -11,9 +11,15 @@ package mysql import ( - "appengine/cloudsql" + "context" + "net" + + "google.golang.org/appengine/cloudsql" ) func init() { - RegisterDial("cloudsql", cloudsql.Dial) + RegisterDialContext("cloudsql", func(_ context.Context, instance string) (net.Conn, error) { + // XXX: the cloudsql driver still does not export a Context-aware dialer. + return cloudsql.Dial(instance) + }) } diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/auth.go new file mode 100644 index 0000000..fec7040 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/auth.go @@ -0,0 +1,422 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "sync" +) + +// server pub keys registry +var ( + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey +) + +// RegisterServerPubKey registers a server RSA public key which can be used to +// send data in a secure manner to the server without receiving the public key +// in a potentially insecure way from the server first. +// Registered keys can afterwards be used adding serverPubKey= to the DSN. +// +// Note: The provided rsa.PublicKey instance is exclusively owned by the driver +// after registering it and may not be modified. +// +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } +// +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } +// +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } +// +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } +// +func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry == nil { + serverPubKeyRegistry = make(map[string]*rsa.PublicKey) + } + + serverPubKeyRegistry[name] = pubKey + serverPubKeyLock.Unlock() +} + +// DeregisterServerPubKey removes the public key registered with the given name. +func DeregisterServerPubKey(name string) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry != nil { + delete(serverPubKeyRegistry, name) + } + serverPubKeyLock.Unlock() +} + +func getServerPubKey(name string) (pubKey *rsa.PublicKey) { + serverPubKeyLock.RLock() + if v, ok := serverPubKeyRegistry[name]; ok { + pubKey = v + } + serverPubKeyLock.RUnlock() + return +} + +// Hash password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +// Pseudo random number generator +func newMyRnd(seed1, seed2 uint32) *myRnd { + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } +} + +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) +} + +// Generate binary hash from byte string using insecure pre 4.1 method +func pwHash(password []byte) (result [2]uint32) { + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 + + for _, c := range password { + // skip spaces and tabs in password + if c == ' ' || c == '\t' { + continue + } + + tmp = uint32(c) + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] + add += tmp + } + + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + + return +} + +// Hash password using insecure pre 4.1 method +func scrambleOldPassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + scramble = scramble[:8] + + hashPw := pwHash([]byte(password)) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i := range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} + +// Hash password using 4.1+ method (SHA1) +func scramblePassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write([]byte(password)) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} + +func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { + enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) + if err != nil { + return err + } + return mc.writeAuthSwitchPacket(enc) +} + +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { + switch plugin { + case "caching_sha2_password": + authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + return authResp, nil + + case "mysql_old_password": + if !mc.cfg.AllowOldPasswords { + return nil, ErrOldPassword + } + // Note: there are edge cases where this should work but doesn't; + // this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil + + case "mysql_clear_password": + if !mc.cfg.AllowCleartextPasswords { + return nil, ErrCleartextPassword + } + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + return append([]byte(mc.cfg.Passwd), 0), nil + + case "mysql_native_password": + if !mc.cfg.AllowNativePasswords { + return nil, ErrNativePassword + } + // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // Native password authentication only need and will need 20-byte challenge. + authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + return authResp, nil + + case "sha256_password": + if len(mc.cfg.Passwd) == 0 { + return []byte{0}, nil + } + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + return append([]byte(mc.cfg.Passwd), 0), nil + } + + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + return []byte{1}, nil + } + + // encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + return enc, err + + default: + errLog.Print("unknown auth plugin:", plugin) + return nil, ErrUnknownPlugin + } +} + +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { + // Read Result Packet + authData, newPlugin, err := mc.readAuthResult() + if err != nil { + return err + } + + // handle auth plugin switch, if requested + if newPlugin != "" { + // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is + // sent and we have to keep using the cipher sent in the init packet. + if authData == nil { + authData = oldAuthData + } else { + // copy data from read buffer to owned slice + copy(oldAuthData, authData) + } + + plugin = newPlugin + + authResp, err := mc.auth(authData, plugin) + if err != nil { + return err + } + if err = mc.writeAuthSwitchPacket(authResp); err != nil { + return err + } + + // Read Result Packet + authData, newPlugin, err = mc.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if newPlugin != "" { + return ErrMalformPkt + } + } + + switch plugin { + + // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + case "caching_sha2_password": + switch len(authData) { + case 0: + return nil // auth successful + case 1: + switch authData[0] { + case cachingSha2PasswordFastAuthSuccess: + if err = mc.readResultOK(); err == nil { + return nil // auth successful + } + + case cachingSha2PasswordPerformFullAuthentication: + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + if err != nil { + return err + } + } else { + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } + data[4] = cachingSha2PasswordRequestPublicKey + mc.writePacket(data) + + // parse public key + if data, err = mc.readPacket(); err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + pubKey = pkix.(*rsa.PublicKey) + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pubKey) + if err != nil { + return err + } + } + return mc.readResultOK() + + default: + return ErrMalformPkt + } + default: + return ErrMalformPkt + } + + case "sha256_password": + switch len(authData) { + case 0: + return nil // auth successful + default: + block, _ := pem.Decode(authData) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + return mc.readResultOK() + } + + default: + return nil // auth successful + } + + return err +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/auth_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/auth_test.go new file mode 100644 index 0000000..1920ef3 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/auth_test.go @@ -0,0 +1,1330 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "testing" +) + +var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" + + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAol0Z8G8U+25Btxk/g/fm\n" + + "UAW/wEKjQCTjkibDE4B+qkuWeiumg6miIRhtilU6m9BFmLQSy1ltYQuu4k17A4tQ\n" + + "rIPpOQYZges/qsDFkZh3wyK5jL5WEFVdOasf6wsfszExnPmcZS4axxoYJfiuilrN\n" + + "hnwinBAqfi3S0sw5MpSI4Zl1AbOrHG4zDI62Gti2PKiMGyYDZTS9xPrBLbN95Kby\n" + + "FFclQLEzA9RJcS1nHFsWtRgHjGPhhjCQxEm9NQ1nePFhCfBfApyfH1VM2VCOQum6\n" + + "Ci9bMuHWjTjckC84mzF99kOxOWVU7mwS6gnJqBzpuz8t3zq8/iQ2y7QrmZV+jTJP\n" + + "WQIDAQAB\n" + + "-----END PUBLIC KEY-----\n") + +var testPubKeyRSA *rsa.PublicKey + +func init() { + block, _ := pem.Decode(testPubKey) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + testPubKeyRSA = pub.(*rsa.PublicKey) +} + +func TestScrambleOldPass(t *testing.T) { + scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2} + vectors := []struct { + pass string + out string + }{ + {" pass", "47575c5a435b4251"}, + {"pass ", "47575c5a435b4251"}, + {"123\t456", "575c47505b5b5559"}, + {"C0mpl!ca ted#PASS123", "5d5d554849584a45"}, + } + for _, tuple := range vectors { + ours := scrambleOldPassword(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed old password %q", tuple.pass) + } + } +} + +func TestScrambleSHA256Pass(t *testing.T) { + scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21} + vectors := []struct { + pass string + out string + }{ + {"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"}, + {"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"}, + } + for _, tuple := range vectors { + ours := scrambleSHA256Password(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed SHA256 password %q", tuple.pass) + } + } +} + +func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56, + 139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15, + 184, 150, 26, 61, 57, 235} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 3, // Fast Auth Success + 7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // pub key response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthFastCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowNativePasswords = false + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthFastNativePassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252, + 172, 50, 211, 192, 240, 164, 26, 48, 207, 45} + if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{1} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + // hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + + // unset TLS config to prevent the actual establishment of a TLS wrapper + mc.cfg.tls = nil + + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 6, 1}, testPubKey...), + + // OK + {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 4 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Pub Key Request + 1, 0, 0, 5, 2, + + // 3. Packet: Encrypted Password + 0, 1, 0, 7, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {2, 0, 0, 4, 1, 4}, // Perform Full Authentication + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Cleartext password + 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + conn.maxReads = 1 + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthSwitchCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = false + + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + conn.maxReads = 1 + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthSwitchNativePassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103, + 21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +func TestOldAuthSwitchNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +func TestAuthSwitchOldPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +func TestOldAuthSwitch(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} +func TestAuthSwitchOldPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +func TestOldAuthSwitchPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // OldAuthSwitch request. + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Empty Password + 1, 0, 0, 3, 0, + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Pub Key Request + 1, 0, 0, 3, 1, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Encrypted Password + 0, 1, 0, 3, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Cleartext Password + 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/benchmark_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/benchmark_test.go new file mode 100644 index 0000000..3e25a3b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/benchmark_test.go @@ -0,0 +1,373 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "fmt" + "math" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +type TB testing.B + +func (tb *TB) check(err error) { + if err != nil { + tb.Fatal(err) + } +} + +func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB { + tb.check(err) + return db +} + +func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows { + tb.check(err) + return rows +} + +func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { + tb.check(err) + return stmt +} + +func initDB(b *testing.B, queries ...string) *sql.DB { + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + for _, query := range queries { + if _, err := db.Exec(query); err != nil { + b.Fatalf("error on %q: %v", query, err) + } + } + return db +} + +const concurrencyLevel = 10 + +func BenchmarkQuery(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + var got string + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Errorf("query = %q; want one", got) + wg.Done() + return + } + } + }() + } +} + +func BenchmarkExec(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := tb.checkDB(sql.Open("mysql", dsn)) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("DO 1")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + if _, err := stmt.Exec(); err != nil { + b.Fatal(err.Error()) + } + } + }() + } +} + +// data, but no db writes +var roundtripSample []byte + +func initRoundtripBenchmarks() ([]byte, int, int) { + if roundtripSample == nil { + roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024)) + } + return roundtripSample, 16, len(roundtripSample) +} + +func BenchmarkRoundtripTxt(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + sampleString := string(sample) + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + b.StartTimer() + var result string + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sampleString[0:length] + rows := tb.checkRows(db.Query(`SELECT "` + test + `"`)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if result != test { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkRoundtripBin(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + stmt := tb.checkStmt(db.Prepare("SELECT ?")) + defer stmt.Close() + b.StartTimer() + var result sql.RawBytes + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sample[0:length] + rows := tb.checkRows(stmt.Query(test)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if !bytes.Equal(result, test) { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkInterpolation(b *testing.B) { + mc := &mysqlConn{ + cfg: &Config{ + InterpolateParams: true, + Loc: time.UTC, + }, + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + buf: newBuffer(nil), + } + + args := []driver.Value{ + int64(42424242), + float64(math.Pi), + false, + time.Unix(1423411542, 807015000), + []byte("bytes containing special chars ' \" \a \x00"), + "string containing special chars ' \" \a \x00", + } + q := "SELECT ?, ?, ?, ?, ?, ?" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := mc.interpolateParams(q, args) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var got string + for pb.Next() { + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Fatalf("query = %q; want one", got) + } + } + }) +} + +func BenchmarkQueryContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := stmt.ExecContext(ctx); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkExecContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes. +// "size=" means size of each blobs. +func BenchmarkQueryRawBytes(b *testing.B) { + var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} + db := initDB(b, + "DROP TABLE IF EXISTS bench_rawbytes", + "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", + ) + defer db.Close() + + blob := make([]byte, sizes[len(sizes)-1]) + for i := range blob { + blob[i] = 42 + } + for i := 0; i < 100; i++ { + _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) + if err != nil { + b.Fatal(err) + } + } + + for _, s := range sizes { + b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + b.ReportAllocs() + b.ResetTimer() + + for j := 0; j < b.N; j++ { + rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) + if err != nil { + b.Fatal(err) + } + nrows := 0 + for rows.Next() { + var buf sql.RawBytes + err := rows.Scan(&buf) + if err != nil { + b.Fatal(err) + } + if len(buf) != s { + b.Fatalf("size mismatch: expected %v, got %v", s, len(buf)) + } + nrows++ + } + rows.Close() + if nrows != 100 { + b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows) + } + } + }) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/buffer.go new file mode 100644 index 0000000..0774c5c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -0,0 +1,182 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "io" + "net" + "time" +) + +const defaultBufSize = 4096 +const maxCachedBufSize = 256 * 1024 + +// A buffer which is used for both reading and writing. +// This is possible since communication on each connection is synchronous. +// In other words, we can't write and read simultaneously on the same connection. +// The buffer is similar to bufio.Reader / Writer but zero-copy-ish +// Also highly optimized for this particular use case. +// This buffer is backed by two byte slices in a double-buffering scheme +type buffer struct { + buf []byte // buf is a byte buffer who's length and capacity are equal. + nc net.Conn + idx int + length int + timeout time.Duration + dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer + flipcnt uint // flipccnt is the current buffer counter for double-buffering +} + +// newBuffer allocates and returns a new buffer. +func newBuffer(nc net.Conn) buffer { + fg := make([]byte, defaultBufSize) + return buffer{ + buf: fg, + nc: nc, + dbuf: [2][]byte{fg, nil}, + } +} + +// flip replaces the active buffer with the background buffer +// this is a delayed flip that simply increases the buffer counter; +// the actual flip will be performed the next time we call `buffer.fill` +func (b *buffer) flip() { + b.flipcnt += 1 +} + +// fill reads into the buffer until at least _need_ bytes are in it +func (b *buffer) fill(need int) error { + n := b.length + // fill data into its double-buffering target: if we've called + // flip on this buffer, we'll be copying to the background buffer, + // and then filling it with network data; otherwise we'll just move + // the contents of the current buffer to the front before filling it + dest := b.dbuf[b.flipcnt&1] + + // grow buffer if necessary to fit the whole packet. + if need > len(dest) { + // Round up to the next multiple of the default size + dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + + // if the allocated buffer is not too large, move it to backing storage + // to prevent extra allocations on applications that perform large reads + if len(dest) <= maxCachedBufSize { + b.dbuf[b.flipcnt&1] = dest + } + } + + // if we're filling the fg buffer, move the existing data to the start of it. + // if we're filling the bg buffer, copy over the data + if n > 0 { + copy(dest[:n], b.buf[b.idx:]) + } + + b.buf = dest + b.idx = 0 + + for { + if b.timeout > 0 { + if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + return err + } + } + + nn, err := b.nc.Read(b.buf[n:]) + n += nn + + switch err { + case nil: + if n < need { + continue + } + b.length = n + return nil + + case io.EOF: + if n >= need { + b.length = n + return nil + } + return io.ErrUnexpectedEOF + + default: + return err + } + } +} + +// returns next N bytes from buffer. +// The returned slice is only guaranteed to be valid until the next read +func (b *buffer) readNext(need int) ([]byte, error) { + if b.length < need { + // refill + if err := b.fill(need); err != nil { + return nil, err + } + } + + offset := b.idx + b.idx += need + b.length -= need + return b.buf[offset:b.idx], nil +} + +// takeBuffer returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeBuffer(length int) ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + + // test (cheap) general case first + if length <= cap(b.buf) { + return b.buf[:length], nil + } + + if length < maxPacketSize { + b.buf = make([]byte, length) + return b.buf, nil + } + + // buffer is larger than we want to store. + return make([]byte, length), nil +} + +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + return b.buf[:length], nil +} + +// takeCompleteBuffer returns the complete existing buffer. +// This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeCompleteBuffer() ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { + if b.length > 0 { + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] + } + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/collations.go new file mode 100644 index 0000000..8d2b556 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/collations.go @@ -0,0 +1,265 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +const defaultCollation = "utf8mb4_general_ci" +const binaryCollation = "binary" + +// A list of available collations mapped to the internal ID. +// To update this map use the following MySQL query: +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID +// +// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. +// +// ucs2, utf16, and utf32 can't be used for connection charset. +// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset +// They are commented out to reduce this map. +var collations = map[string]byte{ + "big5_chinese_ci": 1, + "latin2_czech_cs": 2, + "dec8_swedish_ci": 3, + "cp850_general_ci": 4, + "latin1_german1_ci": 5, + "hp8_english_ci": 6, + "koi8r_general_ci": 7, + "latin1_swedish_ci": 8, + "latin2_general_ci": 9, + "swe7_swedish_ci": 10, + "ascii_general_ci": 11, + "ujis_japanese_ci": 12, + "sjis_japanese_ci": 13, + "cp1251_bulgarian_ci": 14, + "latin1_danish_ci": 15, + "hebrew_general_ci": 16, + "tis620_thai_ci": 18, + "euckr_korean_ci": 19, + "latin7_estonian_cs": 20, + "latin2_hungarian_ci": 21, + "koi8u_general_ci": 22, + "cp1251_ukrainian_ci": 23, + "gb2312_chinese_ci": 24, + "greek_general_ci": 25, + "cp1250_general_ci": 26, + "latin2_croatian_ci": 27, + "gbk_chinese_ci": 28, + "cp1257_lithuanian_ci": 29, + "latin5_turkish_ci": 30, + "latin1_german2_ci": 31, + "armscii8_general_ci": 32, + "utf8_general_ci": 33, + "cp1250_czech_cs": 34, + //"ucs2_general_ci": 35, + "cp866_general_ci": 36, + "keybcs2_general_ci": 37, + "macce_general_ci": 38, + "macroman_general_ci": 39, + "cp852_general_ci": 40, + "latin7_general_ci": 41, + "latin7_general_cs": 42, + "macce_bin": 43, + "cp1250_croatian_ci": 44, + "utf8mb4_general_ci": 45, + "utf8mb4_bin": 46, + "latin1_bin": 47, + "latin1_general_ci": 48, + "latin1_general_cs": 49, + "cp1251_bin": 50, + "cp1251_general_ci": 51, + "cp1251_general_cs": 52, + "macroman_bin": 53, + //"utf16_general_ci": 54, + //"utf16_bin": 55, + //"utf16le_general_ci": 56, + "cp1256_general_ci": 57, + "cp1257_bin": 58, + "cp1257_general_ci": 59, + //"utf32_general_ci": 60, + //"utf32_bin": 61, + //"utf16le_bin": 62, + "binary": 63, + "armscii8_bin": 64, + "ascii_bin": 65, + "cp1250_bin": 66, + "cp1256_bin": 67, + "cp866_bin": 68, + "dec8_bin": 69, + "greek_bin": 70, + "hebrew_bin": 71, + "hp8_bin": 72, + "keybcs2_bin": 73, + "koi8r_bin": 74, + "koi8u_bin": 75, + "utf8_tolower_ci": 76, + "latin2_bin": 77, + "latin5_bin": 78, + "latin7_bin": 79, + "cp850_bin": 80, + "cp852_bin": 81, + "swe7_bin": 82, + "utf8_bin": 83, + "big5_bin": 84, + "euckr_bin": 85, + "gb2312_bin": 86, + "gbk_bin": 87, + "sjis_bin": 88, + "tis620_bin": 89, + //"ucs2_bin": 90, + "ujis_bin": 91, + "geostd8_general_ci": 92, + "geostd8_bin": 93, + "latin1_spanish_ci": 94, + "cp932_japanese_ci": 95, + "cp932_bin": 96, + "eucjpms_japanese_ci": 97, + "eucjpms_bin": 98, + "cp1250_polish_ci": 99, + //"utf16_unicode_ci": 101, + //"utf16_icelandic_ci": 102, + //"utf16_latvian_ci": 103, + //"utf16_romanian_ci": 104, + //"utf16_slovenian_ci": 105, + //"utf16_polish_ci": 106, + //"utf16_estonian_ci": 107, + //"utf16_spanish_ci": 108, + //"utf16_swedish_ci": 109, + //"utf16_turkish_ci": 110, + //"utf16_czech_ci": 111, + //"utf16_danish_ci": 112, + //"utf16_lithuanian_ci": 113, + //"utf16_slovak_ci": 114, + //"utf16_spanish2_ci": 115, + //"utf16_roman_ci": 116, + //"utf16_persian_ci": 117, + //"utf16_esperanto_ci": 118, + //"utf16_hungarian_ci": 119, + //"utf16_sinhala_ci": 120, + //"utf16_german2_ci": 121, + //"utf16_croatian_ci": 122, + //"utf16_unicode_520_ci": 123, + //"utf16_vietnamese_ci": 124, + //"ucs2_unicode_ci": 128, + //"ucs2_icelandic_ci": 129, + //"ucs2_latvian_ci": 130, + //"ucs2_romanian_ci": 131, + //"ucs2_slovenian_ci": 132, + //"ucs2_polish_ci": 133, + //"ucs2_estonian_ci": 134, + //"ucs2_spanish_ci": 135, + //"ucs2_swedish_ci": 136, + //"ucs2_turkish_ci": 137, + //"ucs2_czech_ci": 138, + //"ucs2_danish_ci": 139, + //"ucs2_lithuanian_ci": 140, + //"ucs2_slovak_ci": 141, + //"ucs2_spanish2_ci": 142, + //"ucs2_roman_ci": 143, + //"ucs2_persian_ci": 144, + //"ucs2_esperanto_ci": 145, + //"ucs2_hungarian_ci": 146, + //"ucs2_sinhala_ci": 147, + //"ucs2_german2_ci": 148, + //"ucs2_croatian_ci": 149, + //"ucs2_unicode_520_ci": 150, + //"ucs2_vietnamese_ci": 151, + //"ucs2_general_mysql500_ci": 159, + //"utf32_unicode_ci": 160, + //"utf32_icelandic_ci": 161, + //"utf32_latvian_ci": 162, + //"utf32_romanian_ci": 163, + //"utf32_slovenian_ci": 164, + //"utf32_polish_ci": 165, + //"utf32_estonian_ci": 166, + //"utf32_spanish_ci": 167, + //"utf32_swedish_ci": 168, + //"utf32_turkish_ci": 169, + //"utf32_czech_ci": 170, + //"utf32_danish_ci": 171, + //"utf32_lithuanian_ci": 172, + //"utf32_slovak_ci": 173, + //"utf32_spanish2_ci": 174, + //"utf32_roman_ci": 175, + //"utf32_persian_ci": 176, + //"utf32_esperanto_ci": 177, + //"utf32_hungarian_ci": 178, + //"utf32_sinhala_ci": 179, + //"utf32_german2_ci": 180, + //"utf32_croatian_ci": 181, + //"utf32_unicode_520_ci": 182, + //"utf32_vietnamese_ci": 183, + "utf8_unicode_ci": 192, + "utf8_icelandic_ci": 193, + "utf8_latvian_ci": 194, + "utf8_romanian_ci": 195, + "utf8_slovenian_ci": 196, + "utf8_polish_ci": 197, + "utf8_estonian_ci": 198, + "utf8_spanish_ci": 199, + "utf8_swedish_ci": 200, + "utf8_turkish_ci": 201, + "utf8_czech_ci": 202, + "utf8_danish_ci": 203, + "utf8_lithuanian_ci": 204, + "utf8_slovak_ci": 205, + "utf8_spanish2_ci": 206, + "utf8_roman_ci": 207, + "utf8_persian_ci": 208, + "utf8_esperanto_ci": 209, + "utf8_hungarian_ci": 210, + "utf8_sinhala_ci": 211, + "utf8_german2_ci": 212, + "utf8_croatian_ci": 213, + "utf8_unicode_520_ci": 214, + "utf8_vietnamese_ci": 215, + "utf8_general_mysql500_ci": 223, + "utf8mb4_unicode_ci": 224, + "utf8mb4_icelandic_ci": 225, + "utf8mb4_latvian_ci": 226, + "utf8mb4_romanian_ci": 227, + "utf8mb4_slovenian_ci": 228, + "utf8mb4_polish_ci": 229, + "utf8mb4_estonian_ci": 230, + "utf8mb4_spanish_ci": 231, + "utf8mb4_swedish_ci": 232, + "utf8mb4_turkish_ci": 233, + "utf8mb4_czech_ci": 234, + "utf8mb4_danish_ci": 235, + "utf8mb4_lithuanian_ci": 236, + "utf8mb4_slovak_ci": 237, + "utf8mb4_spanish2_ci": 238, + "utf8mb4_roman_ci": 239, + "utf8mb4_persian_ci": 240, + "utf8mb4_esperanto_ci": 241, + "utf8mb4_hungarian_ci": 242, + "utf8mb4_sinhala_ci": 243, + "utf8mb4_german2_ci": 244, + "utf8mb4_croatian_ci": 245, + "utf8mb4_unicode_520_ci": 246, + "utf8mb4_vietnamese_ci": 247, + "gb18030_chinese_ci": 248, + "gb18030_bin": 249, + "gb18030_unicode_520_ci": 250, + "utf8mb4_0900_ai_ci": 255, +} + +// A blacklist of collations which is unsafe to interpolate parameters. +// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. +var unsafeCollations = map[string]bool{ + "big5_chinese_ci": true, + "sjis_japanese_ci": true, + "gbk_chinese_ci": true, + "big5_bin": true, + "gb2312_bin": true, + "gbk_bin": true, + "sjis_bin": true, + "cp932_japanese_ci": true, + "cp932_bin": true, + "gb18030_chinese_ci": true, + "gb18030_bin": true, + "gb18030_unicode_520_ci": true, +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck.go new file mode 100644 index 0000000..cc47aa5 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck.go @@ -0,0 +1,53 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !windows,!appengine + +package mysql + +import ( + "errors" + "io" + "net" + "syscall" +) + +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(c net.Conn) error { + var ( + n int + err error + buff [1]byte + ) + + sconn, ok := c.(syscall.Conn) + if !ok { + return nil + } + rc, err := sconn.SyscallConn() + if err != nil { + return err + } + rerr := rc.Read(func(fd uintptr) bool { + n, err = syscall.Read(int(fd), buff[:]) + return true + }) + switch { + case rerr != nil: + return rerr + case n == 0 && err == nil: + return io.EOF + case n > 0: + return errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + return nil + default: + return err + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/appengine.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go similarity index 63% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/appengine.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go index 565614e..fd01f64 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/appengine.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go @@ -1,19 +1,17 @@ // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build appengine +// +build windows appengine package mysql -import ( - "appengine/cloudsql" -) +import "net" -func init() { - RegisterDial("cloudsql", cloudsql.Dial) +func connCheck(c net.Conn) error { + return nil } diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck_test.go new file mode 100644 index 0000000..b7234b0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/conncheck_test.go @@ -0,0 +1,38 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10,!windows + +package mysql + +import ( + "testing" + "time" +) + +func TestStaleConnectionChecks(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("SET @@SESSION.wait_timeout = 2") + + if err := dbt.db.Ping(); err != nil { + dbt.Fatal(err) + } + + // wait for MySQL to close our connection + time.Sleep(3 * time.Second) + + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + dbt.Fatal(err) + } + }) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connection.go similarity index 51% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/connection.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connection.go index c3899de..565a548 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connection.go @@ -9,7 +9,10 @@ package mysql import ( + "context" + "database/sql" "database/sql/driver" + "io" "net" "strconv" "strings" @@ -19,17 +22,26 @@ import ( type mysqlConn struct { buf buffer netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. affectedRows uint64 insertId uint64 cfg *Config - maxPacketAllowed int + maxAllowedPacket int maxWriteSize int writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 parseTime bool - strict bool + reset bool // set when the Go SQL package calls ResetSession + + // for context support (Go 1.8+) + watching bool + watcher chan<- context.Context + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed } // Handles parameters set in DSN after the connection is established @@ -62,22 +74,41 @@ func (mc *mysqlConn) handleParams() (err error) { return } +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.netConn == nil { + return mc.begin(false) +} + +func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + var q string + if readOnly { + q = "START TRANSACTION READ ONLY" + } else { + q = "START TRANSACTION" + } + err := mc.exec(q) if err == nil { return &mysqlTx{mc}, err } - - return nil, err + return nil, mc.markBadConn(err) } func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if mc.netConn != nil { + if !mc.closed.IsSet() { err = mc.writeCommandPacket(comQuit) } @@ -91,26 +122,39 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { - // Makes cleanup idempotent - if mc.netConn != nil { - if err := mc.netConn.Close(); err != nil { - errLog.Print(err) - } - mc.netConn = nil + if !mc.closed.TrySet(true) { + return } - mc.cfg = nil - mc.buf.nc = nil + + // Makes cleanup idempotent + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) error() error { + if mc.closed.IsSet() { + if err := mc.canceled.Value(); err != nil { + return err + } + return ErrInvalidConn + } + return nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { - return nil, err + return nil, mc.markBadConn(err) } stmt := &mysqlStmt{ @@ -135,11 +179,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { - buf := mc.buf.takeCompleteBuffer() - if buf == nil { + // Number of ? should be same to len(args) + if strings.Count(query, "?") != len(args) { + return "", driver.ErrSkip + } + + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return "", driver.ErrBadConn + errLog.Print(err) + return "", ErrInvalidConn } buf = buf[:0] argPos := 0 @@ -164,6 +213,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin switch v := arg.(type) { case int64: buf = strconv.AppendInt(buf, v, 10) + case uint64: + // Handle uint64 explicitly because our custom ConvertValue emits unsigned values + buf = strconv.AppendUint(buf, v, 10) case float64: buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: @@ -241,7 +293,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - if len(buf)+4 > mc.maxPacketAllowed { + if len(buf)+4 > mc.maxAllowedPacket { return "", driver.ErrSkip } } @@ -252,7 +304,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -266,7 +318,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, err } query = prepared - args = nil } mc.affectedRows = 0 mc.insertId = 0 @@ -278,32 +329,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err insertId: int64(mc.insertId), }, err } - return nil, err + return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command - err := mc.writeCommandPacketStr(comQuery, query) - if err != nil { - return err + if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + return mc.markBadConn(err) } // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil && resLen > 0 { - if err = mc.readUntilEOF(); err != nil { + if err != nil { + return err + } + + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { return err } - err = mc.readUntilEOF() + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } } - return err + return mc.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.netConn == nil { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -317,7 +379,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, err } query = prepared - args = nil } // Send command err := mc.writeCommandPacketStr(comQuery, query) @@ -330,15 +391,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro rows.mc = mc if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } } + // Columns - rows.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } - return nil, err + return nil, mc.markBadConn(err) } // Gets the value of the given MySQL System Variable @@ -354,7 +422,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if err == nil { rows := new(textRows) rows.mc = mc - rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns @@ -370,3 +438,212 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } return nil, err } + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.canceled.Set(err) + mc.cleanup() +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return mc.markBadConn(err) + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. + if ctx.Done() == nil { + return nil + } + // When watcher is not alive, can't watch it. + if mc.watcher == nil { + return nil + } + + mc.watching = true + mc.watcher <- ctx + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan context.Context, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx context.Context + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + mc.reset = true + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connection_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connection_test.go new file mode 100644 index 0000000..19c17ff --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connection_test.go @@ -0,0 +1,175 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "errors" + "net" + "testing" +) + +func TestInterpolateParams(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT 42+'gopher'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + +func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +// We don't support placeholder in string literal for now. +// https://github.com/go-sql-driver/mysql/pull/490 +func TestInterpolateParamsPlaceholderInString(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) + // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +func TestInterpolateParamsUint64(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)}) + if err != nil { + t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q) + } + if q != "SELECT 42" { + t.Errorf("Expected uint64 interpolation to work, got q=%#v", q) + } +} + +func TestCheckNamedValue(t *testing.T) { + value := driver.NamedValue{Value: ^uint64(0)} + x := &mysqlConn{} + err := x.CheckNamedValue(&value) + + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if value.Value != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value) + } +} + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +} + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connector.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connector.go new file mode 100644 index 0000000..5aaaba4 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/connector.go @@ -0,0 +1,143 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "net" +) + +type connector struct { + cfg *Config // immutable private copy. +} + +// Connect implements driver.Connector interface. +// Connect returns a connection to the database. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: c.cfg, + } + mc.parseTime = mc.cfg.ParseTime + + // Connect to Server + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + mc.netConn, err = dial(ctx, mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + } + + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + errLog.Print("net.Error from Dial()': ", nerr.Error()) + return nil, driver.ErrBadConn + } + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + // Call startWatcher for context support (From Go 1.8) + mc.startWatcher() + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + authData, plugin, err := mc.readHandshakePacket() + if err != nil { + mc.cleanup() + return nil, err + } + + if plugin == "" { + plugin = defaultAuthPlugin + } + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = mc.handleAuthResult(authData, plugin); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +// Driver implements driver.Connector interface. +// Driver returns &MySQLDriver{}. +func (c *connector) Driver() driver.Driver { + return &MySQLDriver{} +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/const.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/const.go similarity index 84% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/const.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/const.go index 88cfff3..b1e6b85 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/const.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/const.go @@ -9,7 +9,9 @@ package mysql const ( - minProtocolVersion byte = 10 + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" ) @@ -18,10 +20,11 @@ const ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html const ( - iOK byte = 0x00 - iLocalInFile byte = 0xfb - iEOF byte = 0xfe - iERR byte = 0xff + iOK byte = 0x00 + iAuthMoreData byte = 0x01 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags @@ -87,8 +90,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + const ( - fieldTypeDecimal byte = iota + fieldTypeDecimal fieldType = iota fieldTypeTiny fieldTypeShort fieldTypeLong @@ -107,7 +112,7 @@ const ( fieldTypeBit ) const ( - fieldTypeJSON byte = iota + 0xf5 + fieldTypeJSON fieldType = iota + 0xf5 fieldTypeNewDecimal fieldTypeEnum fieldTypeSet @@ -161,3 +166,9 @@ const ( statusInTransReadonly statusSessionStateChanged ) + +const ( + cachingSha2PasswordRequestPublicKey = 2 + cachingSha2PasswordFastAuthSuccess = 3 + cachingSha2PasswordPerformFullAuthentication = 4 +) diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver.go new file mode 100644 index 0000000..1f9decf --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver.go @@ -0,0 +1,85 @@ +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package mysql provides a MySQL driver for Go's database/sql package. +// +// The driver should be used via the database/sql package: +// +// import "database/sql" +// import _ "github.com/go-sql-driver/mysql" +// +// db, err := sql.Open("mysql", "user:password@/dbname") +// +// See https://github.com/go-sql-driver/mysql#usage for details +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "net" + "sync" +) + +// MySQLDriver is exported to make the driver directly accessible. +// In general the driver is used via the database/sql package. +type MySQLDriver struct{} + +// DialFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDial +// +// Deprecated: users should register a DialContextFunc instead +type DialFunc func(addr string) (net.Conn, error) + +// DialContextFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDialContext +type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error) + +var ( + dialsLock sync.RWMutex + dials map[string]DialContextFunc +) + +// RegisterDialContext registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// The current context for the connection and its address is passed to the dial function. +func RegisterDialContext(net string, dial DialContextFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() + if dials == nil { + dials = make(map[string]DialContextFunc) + } + dials[net] = dial +} + +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +// +// Deprecated: users should call RegisterDialContext instead +func RegisterDial(network string, dial DialFunc) { + RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) { + return dial(addr) + }) +} + +// Open new Connection. +// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how +// the DSN string is formatted +func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + c := &connector{ + cfg: cfg, + } + return c.Connect(context.Background()) +} + +func init() { + sql.Register("mysql", &MySQLDriver{}) +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_go110.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_go110.go new file mode 100644 index 0000000..eb5a8fe --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_go110.go @@ -0,0 +1,37 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "database/sql/driver" +) + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + cfg = cfg.Clone() + // normalize the contents of cfg so calls to NewConnector have the same + // behavior as MySQLDriver.OpenConnector + if err := cfg.normalize(); err != nil { + return nil, err + } + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_go110_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_go110_test.go new file mode 100644 index 0000000..19a0e59 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_go110_test.go @@ -0,0 +1,137 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "net" + "testing" + "time" +) + +var _ driver.DriverContext = &MySQLDriver{} + +type dialCtxKey struct{} + +func TestConnectorObeysDialTimeouts(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + if !ctx.Value(dialCtxKey{}).(bool) { + return nil, fmt.Errorf("test error: query context is not propagated to our dialer") + } + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + ctx := context.WithValue(context.Background(), dialCtxKey{}, true) + + _, err = db.ExecContext(ctx, "DO 1") + if err != nil { + t.Fatal(err) + } +} + +func configForTests(t *testing.T) *Config { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mycnf := NewConfig() + mycnf.User = user + mycnf.Passwd = pass + mycnf.Addr = addr + mycnf.Net = prot + mycnf.DBName = dbname + return mycnf +} + +func TestNewConnector(t *testing.T) { + mycnf := configForTests(t) + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +type slowConnection struct { + net.Conn + slowdown time.Duration +} + +func (sc *slowConnection) Read(b []byte) (int, error) { + time.Sleep(sc.slowdown) + return sc.Conn.Read(b) +} + +type connectorHijack struct { + driver.Connector + connErr error +} + +func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + conn, cw.connErr = cw.Connector.Connect(ctx) + return conn, cw.connErr +} + +func TestConnectorTimeoutsDuringOpen(t *testing.T) { + RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, prot, addr) + if err != nil { + return nil, err + } + return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil + }) + + mycnf := configForTests(t) + mycnf.Net = "slowconn" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + hijack := &connectorHijack{Connector: conn} + + db := sql.OpenDB(hijack) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = db.ExecContext(ctx, "DO 1") + if err != context.DeadlineExceeded { + t.Fatalf("ExecContext should have timed out") + } + if hijack.connErr != context.DeadlineExceeded { + t.Fatalf("(*Connector).Connect should have timed out") + } +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_test.go new file mode 100644 index 0000000..3dee1ba --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/driver_test.go @@ -0,0 +1,2996 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "context" + "crypto/tls" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net" + "net/url" + "os" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Ensure that all the driver interfaces are implemented +var ( + _ driver.Rows = &binaryRows{} + _ driver.Rows = &textRows{} +) + +var ( + user string + pass string + prot string + addr string + dbname string + dsn string + netAddr string + available bool +) + +var ( + tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) + sDate = "2012-06-14" + tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC) + sDateTime = "2011-11-20 21:27:37" + tDate0 = time.Time{} + sDate0 = "0000-00-00" + sDateTime0 = "0000-00-00 00:00:00" +) + +// See https://github.com/go-sql-driver/mysql/wiki/Testing +func init() { + // get environment variables + env := func(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue + } + user = env("MYSQL_TEST_USER", "root") + pass = env("MYSQL_TEST_PASS", "") + prot = env("MYSQL_TEST_PROT", "tcp") + addr = env("MYSQL_TEST_ADDR", "localhost:3306") + dbname = env("MYSQL_TEST_DBNAME", "gotest") + netAddr = fmt.Sprintf("%s(%s)", prot, addr) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname) + c, err := net.Dial(prot, addr) + if err == nil { + available = true + c.Close() + } +} + +type DBTest struct { + *testing.T + db *sql.DB +} + +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + +func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn += "&multiStatements=true" + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + } +} + +func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + db.Exec("DROP TABLE IF EXISTS test") + + dsn2 := dsn + "&interpolateParams=true" + var db2 *sql.DB + if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + db2, err = sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db2.Close() + } + + dsn3 := dsn + "&multiStatements=true" + var db3 *sql.DB + if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + db3, err = sql.Open("mysql", dsn3) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db3.Close() + } + + dbt := &DBTest{t, db} + dbt2 := &DBTest{t, db2} + dbt3 := &DBTest{t, db3} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + if db2 != nil { + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + } + if db3 != nil { + test(dbt3) + dbt3.db.Exec("DROP TABLE IF EXISTS test") + } + } +} + +func (dbt *DBTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" + } + dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) +} + +func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + res, err := dbt.db.Exec(query, args...) + if err != nil { + dbt.fail("exec", query, err) + } + return res +} + +func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + rows, err := dbt.db.Query(query, args...) + if err != nil { + dbt.fail("query", query, err) + } + return rows +} + +func maybeSkip(t *testing.T, err error, skipErrno uint16) { + mySQLErr, ok := err.(*MySQLError) + if !ok { + return + } + + if mySQLErr.Number == skipErrno { + t.Skipf("skipping test for error: %v", err) + } +} + +func TestEmptyQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // just a comment, no query + rows := dbt.mustQuery("--") + defer rows.Close() + // will hang before #255 + if rows.Next() { + dbt.Errorf("next on rows must be false") + } + }) +} + +func TestCRUD(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + + // Test for unexpected data + var out bool + rows := dbt.mustQuery("SELECT * FROM test") + if rows.Next() { + dbt.Error("unexpected data in empty table") + } + rows.Close() + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) + } + if id != 0 { + dbt.Fatalf("expected InsertId 0, got %d", id) + } + + // Read + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if true != out { + dbt.Errorf("true != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Update + res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check Update + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if false != out { + dbt.Errorf("false != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Delete + res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check for unexpected rows + res = dbt.mustExec("DELETE FROM test") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 0 { + dbt.Fatalf("expected 0 affected row, got %d", count) + } + }) +} + +func TestMultiQuery(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Update + res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Read + var out int + rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") + if rows.Next() { + rows.Scan(&out) + if 5 != out { + dbt.Errorf("5 != %d", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + }) +} + +func TestInt(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} + in := int64(42) + var out int64 + var rows *sql.Rows + + // SIGNED + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // UNSIGNED ZEROFILL + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s ZEROFILL: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat32(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + in := float32(42.23) + var out float32 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64Placeholder(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") + rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestString(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} + in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" + var out string + var rows *sql.Rows + + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // BLOB + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + + id := 2 + in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + + "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." + dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) + + err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + if err != nil { + dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) + } else if out != in { + dbt.Errorf("BLOB: %s != %s", in, out) + } + }) +} + +func TestRawBytes(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := []byte("aaa") + v2 := []byte("bbb") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 sql.RawBytes + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + // https://github.com/go-sql-driver/mysql/issues/765 + // Appending to RawBytes shouldn't overwrite next RawBytes. + o1 = append(o1, "xyzzy"...) + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + +type testValuer struct { + value string +} + +func (tv testValuer) Value() (driver.Value, error) { + return tv.value, nil +} + +func TestValuer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuer{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + +type testValuerWithValidation struct { + value string +} + +func (tv testValuerWithValidation) Value() (driver.Value, error) { + if len(tv.value) == 0 { + return nil, fmt.Errorf("Invalid string valuer. Value must not be empty") + } + + return tv.value, nil +} + +func TestValuerWithValidation(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuerWithValidation{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM testValuer") + defer rows.Close() + + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + dbt.Errorf("Failed to check valuer error") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + dbt.Errorf("Failed to check nil") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + dbt.Errorf("Failed to check not valuer") + } + + dbt.mustExec("DROP TABLE IF EXISTS testValuer") + }) +} + +type timeTests struct { + dbtype string + tlayout string + tests []timeTest +} + +type timeTest struct { + s string // leading "!": do not use t as value in queries + t time.Time +} + +type timeMode byte + +func (t timeMode) String() string { + switch t { + case binaryString: + return "binary:string" + case binaryTime: + return "binary:time.Time" + case textString: + return "text:string" + } + panic("unsupported timeMode") +} + +func (t timeMode) Binary() bool { + switch t { + case binaryString, binaryTime: + return true + } + return false +} + +const ( + binaryString timeMode = iota + binaryTime + textString +) + +func (t timeTest) genQuery(dbtype string, mode timeMode) string { + var inner string + if mode.Binary() { + inner = "?" + } else { + inner = `"%s"` + } + return `SELECT cast(` + inner + ` as ` + dbtype + `)` +} + +func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { + var rows *sql.Rows + query := t.genQuery(dbtype, mode) + switch mode { + case binaryString: + rows = dbt.mustQuery(query, t.s) + case binaryTime: + rows = dbt.mustQuery(query, t.t) + case textString: + query = fmt.Sprintf(query, t.s) + rows = dbt.mustQuery(query) + default: + panic("unsupported mode") + } + defer rows.Close() + var err error + if !rows.Next() { + err = rows.Err() + if err == nil { + err = fmt.Errorf("no data") + } + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + var dst interface{} + err = rows.Scan(&dst) + if err != nil { + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + switch val := dst.(type) { + case []uint8: + str := string(val) + if str == t.s { + return + } + if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s { + // a fix mainly for TravisCI: + // accept full microsecond resolution in result for DATETIME columns + // where the binary protocol was used + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, str, + ) + case time.Time: + if val == t.t { + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, val.Format(tlayout), + ) + default: + fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t}) + dbt.Errorf("%s [%s]: unhandled type %T (is '%v')", + dbtype, mode, + val, val, + ) + } +} + +func TestDateTime(t *testing.T) { + afterTime := func(t time.Time, d string) time.Time { + dur, err := time.ParseDuration(d) + if err != nil { + panic(err) + } + return t.Add(dur) + } + // NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests + format := "2006-01-02 15:04:05.999999" + t0 := time.Time{} + tstr0 := "0000-00-00 00:00:00.000000" + testcases := []timeTests{ + {"DATE", format[:10], []timeTest{ + {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, + {t: t0, s: tstr0[:10]}, + }}, + {"DATETIME", format[:19], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(0)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(1)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, + {t: t0, s: tstr0[:21]}, + }}, + {"DATETIME(6)", format, []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, + {t: t0, s: tstr0}, + }}, + {"TIME", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(0)", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(1)", format[11:21], []timeTest{ + {t: afterTime(t0, "12345600ms")}, + {s: "!-12:34:56.7"}, + {s: "!-838:59:58.9"}, + {s: "!838:59:58.9"}, + {t: t0, s: tstr0[11:21]}, + }}, + {"TIME(6)", format[11:], []timeTest{ + {t: afterTime(t0, "1234567890123000ns")}, + {s: "!-12:34:56.789012"}, + {s: "!-838:59:58.999999"}, + {s: "!838:59:58.999999"}, + {t: t0, s: tstr0[11:]}, + }}, + } + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + microsecsSupported := false + zeroDateSupported := false + var rows *sql.Rows + var err error + rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) + if err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) + if err == nil { + rows.Scan(&zeroDateSupported) + rows.Close() + } + for _, setups := range testcases { + if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" { + // skip fractional second tests if unsupported by server + continue + } + for _, setup := range setups.tests { + allowBinTime := true + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } else if setup.s[0] == '!' { + // skip tests using setup.t as source in queries + allowBinTime = false + // fix setup.s - remove the "!" + setup.s = setup.s[1:] + } + if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] { + // skip disallowed 0000-00-00 date + continue + } + setup.run(dbt, setups.dbtype, setups.tlayout, textString) + setup.run(dbt, setups.dbtype, setups.tlayout, binaryString) + if allowBinTime { + setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime) + } + } + } + }) + } +} + +func TestTimestampMicros(t *testing.T) { + format := "2006-01-02 15:04:05.999999" + f0 := format[:19] + f1 := format[:21] + f6 := format[:26] + runTests(t, dsn, func(dbt *DBTest) { + // check if microseconds are supported. + // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width + // and not precision. + // Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html + microsecsSupported := false + if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + if !microsecsSupported { + // skip test + return + } + _, err := dbt.db.Exec(` + CREATE TABLE test ( + value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', + value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', + value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' + )`, + ) + if err != nil { + dbt.Error(err) + } + defer dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) + var res0, res1, res6 string + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Errorf("test contained no selectable values") + } + err = rows.Scan(&res0, &res1, &res6) + if err != nil { + dbt.Error(err) + } + if res0 != f0 { + dbt.Errorf("expected %q, got %q", f0, res0) + } + if res1 != f1 { + dbt.Errorf("expected %q, got %q", f1, res1) + } + if res6 != f6 { + dbt.Errorf("expected %q, got %q", f6, res6) + } + }) +} + +func TestNULL(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + nullStmt, err := dbt.db.Prepare("SELECT NULL") + if err != nil { + dbt.Fatal(err) + } + defer nullStmt.Close() + + nonNullStmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + defer nonNullStmt.Close() + + // NullBool + var nb sql.NullBool + // Invalid + if err = nullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if nb.Valid { + dbt.Error("valid NullBool which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if !nb.Valid { + dbt.Error("invalid NullBool which should be valid") + } else if nb.Bool != true { + dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) + } + + // NullFloat64 + var nf sql.NullFloat64 + // Invalid + if err = nullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if nf.Valid { + dbt.Error("valid NullFloat64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if !nf.Valid { + dbt.Error("invalid NullFloat64 which should be valid") + } else if nf.Float64 != float64(1) { + dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) + } + + // NullInt64 + var ni sql.NullInt64 + // Invalid + if err = nullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if ni.Valid { + dbt.Error("valid NullInt64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if !ni.Valid { + dbt.Error("invalid NullInt64 which should be valid") + } else if ni.Int64 != int64(1) { + dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) + } + + // NullString + var ns sql.NullString + // Invalid + if err = nullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if ns.Valid { + dbt.Error("valid NullString which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if !ns.Valid { + dbt.Error("invalid NullString which should be valid") + } else if ns.String != `1` { + dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") + } + + // nil-bytes + var b []byte + // Read nil + if err = nullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil []byte which should be nil") + } + // Read non-nil + if err = nonNullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil []byte which should be non-nil") + } + // Insert nil + b = nil + success := false + if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + dbt.Fatal(err) + } + if !success { + dbt.Error("inserting []byte(nil) as NULL failed") + } + // Check input==output with input==nil + b = nil + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil echo from nil input") + } + // Check input==output with input!=nil + b = []byte("") + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil echo from non-nil input") + } + + // Insert NULL + dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + + dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) + + var out interface{} + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if out != nil { + dbt.Errorf("%v != nil", out) + } + } else { + dbt.Error("no data") + } + }) +} + +func TestUint64(t *testing.T) { + const ( + u0 = uint64(0) + uall = ^u0 + uhigh = uall >> 1 + utop = ^uhigh + s0 = int64(0) + sall = ^s0 + shigh = int64(uhigh) + stop = ^shigh + ) + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + row := stmt.QueryRow( + u0, uhigh, utop, uall, + s0, shigh, stop, sall, + ) + + var ua, ub, uc, ud uint64 + var sa, sb, sc, sd int64 + + err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd) + if err != nil { + dbt.Fatal(err) + } + switch { + case ua != u0, + ub != uhigh, + uc != utop, + ud != uall, + sa != s0, + sb != shigh, + sc != stop, + sd != sall: + dbt.Fatal("unexpected result value") + } + }) +} + +func TestLongData(t *testing.T) { + runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { + var maxAllowedPacketSize int + err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) + if err != nil { + dbt.Fatal(err) + } + maxAllowedPacketSize-- + + // don't get too ambitious + if maxAllowedPacketSize > 1<<25 { + maxAllowedPacketSize = 1 << 25 + } + + dbt.mustExec("CREATE TABLE test (value LONGBLOB)") + + in := strings.Repeat(`a`, maxAllowedPacketSize+1) + var out string + var rows *sql.Rows + + // Long text data + const nonDataQueryLen = 28 // length query w/o value + inS := in[:maxAllowedPacketSize-nonDataQueryLen] + dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if inS != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + dbt.Fatalf("LONGBLOB: no data") + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Long binary data + dbt.mustExec("INSERT INTO test VALUES(?)", in) + rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + if err = rows.Err(); err != nil { + dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error()) + } else { + dbt.Fatal("LONGBLOB: no data (err: )") + } + } + }) +} + +func TestLoadData(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + verifyLoadDataResult := func() { + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + dbt.Fatal(err.Error()) + } + + i := 0 + values := [4]string{ + "a string", + "a string containing a \t", + "a string containing a \n", + "a string containing both \t\n", + } + + var id int + var value string + + for rows.Next() { + i++ + err = rows.Scan(&id, &value) + if err != nil { + dbt.Fatal(err.Error()) + } + if i != id { + dbt.Fatalf("%d != %d", i, id) + } + if values[i-1] != value { + dbt.Fatalf("%q != %q", values[i-1], value) + } + } + err = rows.Err() + if err != nil { + dbt.Fatal(err.Error()) + } + + if i != 4 { + dbt.Fatalf("rows count mismatch. Got %d, want 4", i) + } + } + + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") + + // Local File + file, err := ioutil.TempFile("", "gotest") + defer os.Remove(file.Name()) + if err != nil { + dbt.Fatal(err) + } + RegisterLocalFile(file.Name()) + + // Try first with empty file + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + var count int + err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&count) + if err != nil { + dbt.Fatal(err.Error()) + } + if count != 0 { + dbt.Fatalf("unexpected row count: got %d, want 0", count) + } + + // Then fille File with data and try to load it + file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") + file.Close() + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + verifyLoadDataResult() + + // Try with non-existing file + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent file didn't fail") + } else if err.Error() != "local file 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Reader + RegisterReaderHandler("test", func() io.Reader { + file, err = os.Open(file.Name()) + if err != nil { + dbt.Fatal(err) + } + return file + }) + dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent Reader didn't fail") + } else if err.Error() != "Reader 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + }) +} + +func TestFoundRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + }) + runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 matched rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 3 { + dbt.Fatalf("Expected 3 matched rows, got %d", count) + } + }) +} + +func TestTLS(t *testing.T) { + tlsTestReq := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + if err == ErrNoTLS { + dbt.Skip("server does not support TLS") + } else { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") + defer rows.Close() + + var variable, value *sql.RawBytes + for rows.Next() { + if err := rows.Scan(&variable, &value); err != nil { + dbt.Fatal(err.Error()) + } + + if (*value == nil) || (len(*value) == 0) { + dbt.Fatalf("no Cipher") + } else { + dbt.Logf("Cipher: %s", *value) + } + } + } + tlsTestOpt := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + runTests(t, dsn+"&tls=preferred", tlsTestOpt) + runTests(t, dsn+"&tls=skip-verify", tlsTestReq) + + // Verify that registering / using a custom cfg works + RegisterTLSConfig("custom-skip-verify", &tls.Config{ + InsecureSkipVerify: true, + }) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) +} + +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + md := &MySQLDriver{} + conn, err := md.Open(dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("error closing connection: %s", err.Error()) + } + + defer func() { + if err := recover(); err != nil { + t.Errorf("panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != driver.ErrBadConn { + t.Errorf("unexpected error '%s', expected '%s'", + err.Error(), driver.ErrBadConn.Error()) + } +} + +func TestCharset(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mustSetCharset := func(charsetParam, expected string) { + runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() + + if !rows.Next() { + dbt.Fatalf("error getting connection charset: %s", rows.Err()) + } + + var got string + rows.Scan(&got) + + if got != expected { + dbt.Fatalf("expected connection charset %s but got %s", expected, got) + } + }) + } + + // non utf8 test + mustSetCharset("charset=ascii", "ascii") + + // when the first charset is invalid, use the second + mustSetCharset("charset=none,utf8", "utf8") + + // when the first charset is valid, use it + mustSetCharset("charset=ascii,utf8", "ascii") + mustSetCharset("charset=utf8,ascii", "utf8") +} + +func TestFailingCharset(t *testing.T) { + runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + // run query to really establish connection... + _, err := dbt.db.Exec("SELECT 1") + if err == nil { + dbt.db.Close() + t.Fatalf("connection must not succeed without a valid charset") + } + }) +} + +func TestCollation(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + defaultCollation := "utf8mb4_general_ci" + testCollations := []string{ + "", // do not set + defaultCollation, // driver default + "latin1_general_ci", + "binary", + "utf8_unicode_ci", + "cp1257_bin", + } + + for _, collation := range testCollations { + var expected, tdsn string + if collation != "" { + tdsn = dsn + "&collation=" + collation + expected = collation + } else { + tdsn = dsn + expected = defaultCollation + } + + runTests(t, tdsn, func(dbt *DBTest) { + var got string + if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { + dbt.Fatal(err) + } + + if got != expected { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } + }) + } +} + +func TestColumnsWithAlias(t *testing.T) { + runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1 AS A") + defer rows.Close() + cols, _ := rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A" { + t.Fatalf("expected column name \"A\", got \"%s\"", cols[0]) + } + + rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") + defer rows.Close() + cols, _ = rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A.one" { + t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0]) + } + }) +} + +func TestRawBytesResultExceedsBuffer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // defaultBufSize from buffer.go + expected := strings.Repeat("abc", defaultBufSize) + + rows := dbt.mustQuery("SELECT '" + expected + "'") + defer rows.Close() + if !rows.Next() { + dbt.Error("expected result, got none") + } + var result sql.RawBytes + rows.Scan(&result) + if expected != string(result) { + dbt.Error("result did not match expected value") + } + }) +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + + // Regression test for timezone handling + tzTest := func(dbt *DBTest) { + // Create table + dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") + + // Insert local time into database (should be converted) + usCentral, _ := time.LoadLocation("US/Central") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + dbt.mustExec("INSERT INTO test VALUE (?)", reftime) + + // Retrieve time from DB + rows := dbt.mustQuery("SELECT ts FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Fatal("did not get any rows out") + } + + var dbTime time.Time + err := rows.Scan(&dbTime) + if err != nil { + dbt.Fatal("Err", err) + } + + // Check that dates match + if reftime.Unix() != dbTime.Unix() { + dbt.Errorf("times do not match.\n") + dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(UTC)=%v\n", dbTime) + } + } + + for _, tz := range zones { + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) + } +} + +// Special cases + +func TestRowsClose(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows, err := dbt.db.Query("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + err = rows.Close() + if err != nil { + dbt.Fatal(err) + } + + if rows.Next() { + dbt.Fatal("unexpected row after rows.Close()") + } + + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + }) +} + +// dangling statements +// http://code.google.com/p/go/issues/detail?id=3865 +func TestCloseStmtBeforeRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + rows, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows.Close() + + err = stmt.Close() + if err != nil { + dbt.Fatal(err) + } + + if !rows.Next() { + dbt.Fatal("getting row failed") + } else { + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + + var out bool + err = rows.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + }) +} + +// It is valid to have multiple Rows for the same Stmt +// http://code.google.com/p/go/issues/detail?id=3734 +func TestStmtMultiRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") + if err != nil { + dbt.Fatal(err) + } + + rows1, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows1.Close() + + rows2, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows2.Close() + + var out bool + + // 1 + if !rows1.Next() { + dbt.Fatal("first rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + if !rows2.Next() { + dbt.Fatal("first rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + // 2 + if !rows1.Next() { + dbt.Fatal("second rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows1.Next() { + dbt.Fatal("unexpected row on rows1") + } + err = rows1.Close() + if err != nil { + dbt.Fatal(err) + } + } + + if !rows2.Next() { + dbt.Fatal("second rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows2.Next() { + dbt.Fatal("unexpected row on rows2") + } + err = rows2.Close() + if err != nil { + dbt.Fatal(err) + } + } + }) +} + +// Regression test for +// * more than 32 NULL parameters (issue 209) +// * more parameters than fit into the buffer (issue 201) +// * parameters * 64 > max_allowed_packet (issue 734) +func TestPreparedManyCols(t *testing.T) { + numParams := 65535 + runTests(t, dsn, func(dbt *DBTest) { + query := "SELECT ?" + strings.Repeat(",?", numParams-1) + stmt, err := dbt.db.Prepare(query) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + + // create more parameters than fit into the buffer + // which will take nil-values + params := make([]interface{}, numParams) + rows, err := stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + + // Create 0byte string which we can't send via STMT_LONG_DATA. + for i := 0; i < numParams; i++ { + params[i] = "" + } + rows, err = stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + }) +} + +func TestConcurrent(t *testing.T) { + if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { + t.Skip("MYSQL_TEST_CONCURRENT env var not set") + } + + runTests(t, dsn, func(dbt *DBTest) { + var max int + err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + dbt.Logf("testing up to %d concurrent connections \r\n", max) + + var remaining, succeeded int32 = int32(max), 0 + + var wg sync.WaitGroup + wg.Add(max) + + var fatalError string + var once sync.Once + fatalf := func(s string, vals ...interface{}) { + once.Do(func() { + fatalError = fmt.Sprintf(s, vals...) + }) + } + + for i := 0; i < max; i++ { + go func(id int) { + defer wg.Done() + + tx, err := dbt.db.Begin() + atomic.AddInt32(&remaining, -1) + + if err != nil { + if err.Error() != "Error 1040: Too many connections" { + fatalf("error on conn %d: %s", id, err.Error()) + } + return + } + + // keep the connection busy until all connections are open + for remaining > 0 { + if _, err = tx.Exec("DO 1"); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + } + + if err = tx.Commit(); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + + // everything went fine with this connection + atomic.AddInt32(&succeeded, 1) + }(i) + } + + // wait until all conections are open + wg.Wait() + + if fatalError != "" { + dbt.Fatal(fatalError) + } + + dbt.Logf("reached %d concurrent connections\r\n", succeeded) + }) +} + +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, driver.ErrBadConn) +} + +// Tests custom dial functions +func TestCustomDial(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // our custom dial function which justs wraps net.Dial here + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + if _, err = db.Exec("DO 1"); err != nil { + t.Fatalf("connection failed: %s", err.Error()) + } +} + +func TestSQLInjection(t *testing.T) { + createTest := func(arg string) func(dbt *DBTest) { + return func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (?)", 1) + + var v int + // NULL can't be equal to anything, the idea here is to inject query so it returns row + // This test verifies that escapeQuotes and escapeBackslash are working properly + err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) + if err == sql.ErrNoRows { + return // success, sql injection failed + } else if err == nil { + dbt.Errorf("sql injection successful with arg: %s", arg) + } else { + dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) + } + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("1 OR 1=1")) + runTests(t, testdsn, createTest("' OR '1'='1")) + } +} + +// Test if inserted data is correctly retrieved after being escaped +func TestInsertRetrieveEscapedData(t *testing.T) { + testData := func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") + + // All sequences that are escaped by escapeQuotes and escapeBackslash + v := "foo \x00\n\r\x1a\"'\\" + dbt.mustExec("INSERT INTO test VALUES (?)", v) + + var out string + err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if out != v { + dbt.Errorf("%q != %q", out, v) + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, testData) + } +} + +func TestUnixSocketAuthFail(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Save the current logger so we can restore it. + oldLogger := errLog + + // Set a new logger so we can capture its output. + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + newLogger := log.New(buffer, "prefix: ", 0) + SetLogger(newLogger) + + // Restore the logger. + defer SetLogger(oldLogger) + + // Make a new DSN that uses the MySQL socket file and a bad password, which + // we can make by simply appending any character to the real password. + badPass := pass + "x" + socket := "" + if prot == "unix" { + socket = addr + } else { + // Get socket file from MySQL. + err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) + if err != nil { + t.Fatalf("error on SELECT @@socket: %s", err.Error()) + } + } + t.Logf("socket: %s", socket) + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) + db, err := sql.Open("mysql", badDSN) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + // Connect to MySQL for real. This will cause an auth failure. + err = db.Ping() + if err == nil { + t.Error("expected Ping() to return an error") + } + + // The driver should not log anything. + if actual := buffer.String(); actual != "" { + t.Errorf("expected no output, got %q", actual) + } + }) +} + +// See Issue #422 +func TestInterruptBySignal(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + DROP PROCEDURE IF EXISTS test_signal; + CREATE PROCEDURE test_signal(ret INT) + BEGIN + SELECT ret; + SIGNAL SQLSTATE + '45001' + SET + MESSAGE_TEXT = "an error", + MYSQL_ERRNO = 45001; + END + `) + defer dbt.mustExec("DROP PROCEDURE test_signal") + + var val int + + // text protocol + rows, err := dbt.db.Query("CALL test_signal(42)") + if err != nil { + dbt.Fatalf("error on text query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + + // binary protocol + rows, err = dbt.db.Query("CALL test_signal(?)", 42) + if err != nil { + dbt.Fatalf("error on binary query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + }) +} + +func TestColumnsReusesSlice(t *testing.T) { + rows := mysqlRows{ + rs: resultSet{ + columns: []mysqlField{ + { + tableName: "test", + name: "A", + }, + { + tableName: "test", + name: "B", + }, + }, + }, + } + + allocs := testing.AllocsPerRun(1, func() { + cols := rows.Columns() + + if len(cols) != 2 { + t.Fatalf("expected 2 columns, got %d", len(cols)) + } + }) + + if allocs != 0 { + t.Fatalf("expected 0 allocations, got %d", int(allocs)) + } + + if rows.rs.columnNames == nil { + t.Fatalf("expected columnNames to be set, got nil") + } +} + +func TestRejectReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read-only. We didn't set the `rejectReadOnly` + // option, so any writes after this should fail. + _, err := dbt.db.Exec("SET SESSION TRANSACTION READ ONLY") + // Error 1193: Unknown system variable 'TRANSACTION' => skip test, + // MySQL server version is too old + maybeSkip(t, err, 1193) + if _, err := dbt.db.Exec("DROP TABLE test"); err == nil { + t.Fatalf("writing to DB in read-only session without " + + "rejectReadOnly did not error") + } + // Set the session back to read-write so runTests() can properly clean + // up the table `test`. + dbt.mustExec("SET SESSION TRANSACTION READ WRITE") + }) + + // Enable the `rejectReadOnly` option. + runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read only. Any writes after this should error on + // a driver.ErrBadConn, and cause `database/sql` to initiate a new + // connection. + dbt.mustExec("SET SESSION TRANSACTION READ ONLY") + // This would error, but `database/sql` should automatically retry on a + // new connection which is not read-only, and eventually succeed. + dbt.mustExec("DROP TABLE test") + }) +} + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} + +// See Issue #799 +func TestEmptyPassword(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) + db, err := sql.Open("mysql", dsn) + if err == nil { + defer db.Close() + err = db.Ping() + } + + if pass == "" { + if err != nil { + t.Fatal(err.Error()) + } + } else { + if err == nil { + t.Fatal("expected authentication error") + } + if !strings.HasPrefix(err.Error(), "Error 1045") { + t.Fatal(err.Error()) + } + } +} + +// static interface implementation checks of mysqlConn +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) + +// static interface implementation checks of mysqlStmt +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} +) + +// Ensure that all the driver interfaces are implemented +var ( + // _ driver.RowsColumnTypeLength = &binaryRows{} + // _ driver.RowsColumnTypeLength = &textRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &textRows{} + _ driver.RowsColumnTypeNullable = &binaryRows{} + _ driver.RowsColumnTypeNullable = &textRows{} + _ driver.RowsColumnTypePrecisionScale = &binaryRows{} + _ driver.RowsColumnTypePrecisionScale = &textRows{} + _ driver.RowsColumnTypeScanType = &binaryRows{} + _ driver.RowsColumnTypeScanType = &textRows{} + _ driver.RowsNextResultSet = &binaryRows{} + _ driver.RowsNextResultSet = &textRows{} +) + +func TestMultiResultSet(t *testing.T) { + type result struct { + values [][]int + columns []string + } + + // checkRows is a helper test function to validate rows containing 3 result + // sets with specific values and columns. The basic query would look like this: + // + // SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + // SELECT 0 UNION SELECT 1; + // SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + // + // to distinguish test cases the first string argument is put in front of + // every error or fatal message. + checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) { + expected := []result{ + { + values: [][]int{{1, 2}, {3, 4}}, + columns: []string{"col1", "col2"}, + }, + { + values: [][]int{{1, 2, 3}, {4, 5, 6}}, + columns: []string{"col1", "col2", "col3"}, + }, + } + + var res1 result + for rows.Next() { + var res [2]int + if err := rows.Scan(&res[0], &res[1]); err != nil { + dbt.Fatal(err) + } + res1.values = append(res1.values, res[:]) + } + + cols, err := rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res1.columns = cols + + if !reflect.DeepEqual(expected[0], res1) { + dbt.Error(desc, "want =", expected[0], "got =", res1) + } + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + // ignoring one result set + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + var res2 result + cols, err = rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res2.columns = cols + + for rows.Next() { + var res [3]int + if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { + dbt.Fatal(desc, err) + } + res2.values = append(res2.values, res[:]) + } + + if !reflect.DeepEqual(expected[1], res2) { + dbt.Error(desc, "want =", expected[1], "got =", res2) + } + + if rows.NextResultSet() { + dbt.Error(desc, "unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error(desc, err) + } + } + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery(`DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`) + defer rows.Close() + checkRows("query: ", rows, dbt) + }) + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + queries := []string{ + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + } + + defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss") + + for i, query := range queries { + dbt.mustExec(query) + + stmt, err := dbt.db.Prepare("CALL test_mrss()") + if err != nil { + dbt.Fatalf("%v (i=%d)", err, i) + } + defer stmt.Close() + + for j := 0; j < 2; j++ { + rows, err := stmt.Query() + if err != nil { + dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j) + } + checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt) + } + } + }) +} + +func TestMultiResultSetNoSelect(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("DO 1; DO 2;") + defer rows.Close() + + if rows.Next() { + dbt.Error("unexpected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +// tests if rows are set in a proper state if some results were ignored before +// calling rows.NextResultSet. +func TestSkipResults(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1, 2") + defer rows.Close() + + if !rows.Next() { + dbt.Error("expected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +func TestPingContext(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := dbt.db.PingContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + dbt.Error("expected error") + } else if err.Error() != "context canceled" { + dbt.Fatalf("unexpected error: %s", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQueryRow(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + // the first row will be succeed. + var v int + if !rows.Next() { + dbt.Fatalf("unexpected end") + } + if err := rows.Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + + cancel() + // make sure the driver receives the cancel request. + time.Sleep(100 * time.Millisecond) + + if rows.Next() { + dbt.Errorf("expected end, but not") + } + if err := rows.Err(); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelPrepare(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelStmtExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.ExecContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelStmtQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.QueryContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelBegin(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + tx, err := dbt.db.BeginTx(ctx, nil) + if err != nil { + dbt.Fatal(err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Transaction is canceled, so expect an error. + switch err := tx.Commit(); err { + case sql.ErrTxDone: + // because the transaction has already been rollbacked. + // the database/sql package watches ctx + // and rollbacks when ctx is canceled. + case context.Canceled: + // the database/sql package rollbacks on another goroutine, + // so the transaction may not be rollbacked depending on goroutine scheduling. + default: + dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) + } + + // Context is canceled, so cannot begin a transaction. + if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextBeginIsolationLevel(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx1, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + }) + if err != nil { + dbt.Fatal(err) + } + + tx2, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelReadCommitted, + }) + if err != nil { + dbt.Fatal(err) + } + + _, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if err != nil { + dbt.Fatal(err) + } + + var v int + row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Because writer transaction wasn't commited yet, it should be available + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + err = tx1.Commit() + if err != nil { + dbt.Fatal(err) + } + + row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Data written by writer transaction is already commited, it should be selectable + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + tx2.Commit() + }) +} + +func TestContextBeginReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: true, + }) + if _, ok := err.(*MySQLError); ok { + dbt.Skip("It seems that your MySQL does not support READ ONLY transactions") + return + } else if err != nil { + dbt.Fatal(err) + } + + // INSERT queries fail in a READ ONLY transaction. + _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if _, ok := err.(*MySQLError); !ok { + dbt.Errorf("expected MySQLError, got %v", err) + } + + // SELECT queries can be executed. + var v int + row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + if err := tx.Commit(); err != nil { + dbt.Fatal(err) + } + }) +} + +func TestRowsColumnTypes(t *testing.T) { + niNULL := sql.NullInt64{Int64: 0, Valid: false} + ni0 := sql.NullInt64{Int64: 0, Valid: true} + ni1 := sql.NullInt64{Int64: 1, Valid: true} + ni42 := sql.NullInt64{Int64: 42, Valid: true} + nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} + nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} + nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} + nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := NullTime{Time: time.Time{}, Valid: false} + rbNULL := sql.RawBytes(nil) + rb0 := sql.RawBytes("0") + rb42 := sql.RawBytes("42") + rbTest := sql.RawBytes("Test") + rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 + rbx0 := sql.RawBytes("\x00") + rbx42 := sql.RawBytes("\x42") + + var columns = []struct { + name string + fieldType string // type used when creating table schema + databaseTypeName string // actual type used by MySQL + scanType reflect.Type + nullable bool + precision int64 // 0 if not ok + scale int64 + valuesIn [3]string + valuesOut [3]interface{} + }{ + {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, + {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, + {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, + {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}}, + {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, + {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, + {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, + {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, + {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, + {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, + {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, + {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, + } + + schema := "" + values1 := "" + values2 := "" + values3 := "" + for _, column := range columns { + schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType) + values1 += column.valuesIn[0] + ", " + values2 += column.valuesIn[1] + ", " + values3 += column.valuesIn[2] + ", " + } + schema = schema[:len(schema)-2] + values1 = values1[:len(values1)-2] + values2 = values2[:len(values2)-2] + values3 = values3[:len(values3)-2] + + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } + + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] + + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } + + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } + + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) + } + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { + if !ok { + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) + } + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) + } + continue + } + } + + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) + } + } + } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) + } +} + +func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") + dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) + // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() + }) +} + +// TestRawBytesAreNotModified checks for a race condition that arises when a query context +// is canceled while a user is calling rows.Scan. This is a more stringent test than the one +// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using +// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit +// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers. +func TestRawBytesAreNotModified(t *testing.T) { + const blob = "abcdefghijklmnop" + const contextRaceIterations = 20 + const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row. + const insertRows = 4 + + var sqlBlobs = [2]string{ + strings.Repeat(blob, blobSize/len(blob)), + strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)), + } + + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + for i := 0; i < insertRows; i++ { + dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) + if err != nil { + t.Fatal(err) + } + + var b int + var raw sql.RawBytes + for rows.Next() { + if err := rows.Scan(&b, &raw); err != nil { + t.Fatal(err) + } + + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) + + if before != after { + t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + } + } + rows.Close() + }() + } + }) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/dsn.go similarity index 67% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/dsn.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/dsn.go index 73138bc..1d9b4ab 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -10,11 +10,15 @@ package mysql import ( "bytes" + "crypto/rsa" "crypto/tls" "errors" "fmt" + "math/big" "net" "net/url" + "sort" + "strconv" "strings" "time" ) @@ -26,31 +30,122 @@ var ( errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") ) -// Config is a configuration parsed from a DSN string +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowNativePasswords bool // Allows the native password authentication method AllowOldPasswords bool // Allows the old insecure password method ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time - Strict bool // Return warnings as errors + RejectReadOnly bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, + AllowNativePasswords: true, + } +} + +func (cfg *Config) Clone() *Config { + cp := *cfg + if cp.tls != nil { + cp.tls = cfg.tls.Clone() + } + if len(cp.Params) > 0 { + cp.Params = make(map[string]string, len(cfg.Params)) + for k, v := range cfg.Params { + cp.Params[k] = v + } + } + if cfg.pubKey != nil { + cp.pubKey = &rsa.PublicKey{ + N: new(big.Int).Set(cfg.pubKey.N), + E: cfg.pubKey.E, + } + } + return &cp +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + switch cfg.TLSConfig { + case "false", "": + // don't set anything + case "true": + cfg.tls = &tls.Config{} + case "skip-verify", "preferred": + cfg.tls = &tls.Config{InsecureSkipVerify: true} + default: + cfg.tls = getTLSConfigClone(cfg.TLSConfig) + if cfg.tls == nil { + return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + } + } + + if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + + if cfg.ServerPubKey != "" { + cfg.pubKey = getServerPubKey(cfg.ServerPubKey) + if cfg.pubKey == nil { + return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey) + } + } + + return nil } // FormatDSN formats the given Config into a DSN string which can be passed to @@ -99,6 +194,15 @@ func (cfg *Config) FormatDSN() string { } } + if !cfg.AllowNativePasswords { + if hasParam { + buf.WriteString("&allowNativePasswords=false") + } else { + hasParam = true + buf.WriteString("?allowNativePasswords=false") + } + } + if cfg.AllowOldPasswords { if hasParam { buf.WriteString("&allowOldPasswords=true") @@ -183,15 +287,25 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.ReadTimeout.String()) } - if cfg.Strict { + if cfg.RejectReadOnly { if hasParam { - buf.WriteString("&strict=true") + buf.WriteString("&rejectReadOnly=true") } else { hasParam = true - buf.WriteString("?strict=true") + buf.WriteString("?rejectReadOnly=true") } } + if len(cfg.ServerPubKey) > 0 { + if hasParam { + buf.WriteString("&serverPubKey=") + } else { + hasParam = true + buf.WriteString("?serverPubKey=") + } + buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) + } + if cfg.Timeout > 0 { if hasParam { buf.WriteString("&timeout=") @@ -222,9 +336,25 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.WriteTimeout.String()) } + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { + if hasParam { + buf.WriteString("&maxAllowedPacket=") + } else { + hasParam = true + buf.WriteString("?maxAllowedPacket=") + } + buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket)) + + } + // other params if cfg.Params != nil { - for param, value := range cfg.Params { + var params []string + for param := range cfg.Params { + params = append(params, param) + } + sort.Strings(params) + for _, param := range params { if hasParam { buf.WriteByte('&') } else { @@ -234,7 +364,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(param) buf.WriteByte('=') - buf.WriteString(url.QueryEscape(value)) + buf.WriteString(url.QueryEscape(cfg.Params[param])) } } @@ -244,10 +374,7 @@ func (cfg *Config) FormatDSN() string { // ParseDSN parses the DSN string to a Config func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = &Config{ - Loc: time.UTC, - Collation: defaultCollation, - } + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -315,28 +442,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { return nil, errInvalidDSNNoSlash } - if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { - return nil, errInvalidDSNUnsafeCollation + if err = cfg.normalize(); err != nil { + return nil, err } - - // Set default network if empty - if cfg.Net == "" { - cfg.Net = "tcp" - } - - // Set default address if empty - if cfg.Addr == "" { - switch cfg.Net { - case "tcp": - cfg.Addr = "127.0.0.1:3306" - case "unix": - cfg.Addr = "/tmp/mysql.sock" - default: - return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") - } - - } - return } @@ -351,7 +459,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool @@ -368,6 +475,14 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Use native password authentication + case "allowNativePasswords": + var isBool bool + cfg.AllowNativePasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // Use old authentication mode (pre MySQL 4.1) case "allowOldPasswords": var isBool bool @@ -441,14 +556,26 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } - // Strict mode - case "strict": + // Reject read-only connections + case "rejectReadOnly": var isBool bool - cfg.Strict, isBool = readBool(value) + cfg.RejectReadOnly, isBool = readBool(value) if !isBool { return errors.New("invalid bool value: " + value) } + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + cfg.ServerPubKey = name + + // Strict mode + case "strict": + panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + // Dial Timeout case "timeout": cfg.Timeout, err = time.ParseDuration(value) @@ -462,32 +589,17 @@ func parseDSNParams(cfg *Config, params string) (err error) { if isBool { if boolValue { cfg.TLSConfig = "true" - cfg.tls = &tls.Config{} } else { cfg.TLSConfig = "false" } - } else if vl := strings.ToLower(value); vl == "skip-verify" { + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl - cfg.tls = &tls.Config{InsecureSkipVerify: true} } else { name, err := url.QueryUnescape(value) if err != nil { return fmt.Errorf("invalid value for TLS config name: %v", err) } - - if tlsConfig, ok := tlsConfigRegister[name]; ok { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - tlsConfig.ServerName = host - } - } - - cfg.TLSConfig = name - cfg.tls = tlsConfig - } else { - return errors.New("invalid value / unknown config name: " + name) - } + cfg.TLSConfig = name } // I/O write Timeout @@ -496,7 +608,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } - + case "maxAllowedPacket": + cfg.MaxAllowedPacket, err = strconv.Atoi(value) + if err != nil { + return + } default: // lazy init if cfg.Params == nil { @@ -511,3 +627,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } + +func ensureHavePort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "3306") + } + return addr +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/dsn_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/dsn_test.go new file mode 100644 index 0000000..50dc293 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/dsn_test.go @@ -0,0 +1,415 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/tls" + "fmt" + "net/url" + "reflect" + "testing" + "time" +) + +var testDSNs = []struct { + in string + out *Config +}{{ + "username:password@protocol(address)/dbname?param=value", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true}, +}, { + "user@unix(/path/to/socket)/dbname?charset=utf8", + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"}, +}, { + "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, +}, { + "user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false}, +}, { + "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "@/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "user:p@/ssword@/", + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "unix/?arg=%2Fsome%2Fpath.ext", + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "tcp(127.0.0.1)/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "tcp(de:ad:be:ef::ca:fe)/dbname", + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, +} + +func TestDSNParser(t *testing.T) { + for i, tst := range testDSNs { + cfg, err := ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + } + + // pointer not static + cfg.tls = nil + + if !reflect.DeepEqual(cfg, tst.out) { + t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) + } + } +} + +func TestDSNParserInvalid(t *testing.T) { + var invalidDSNs = []string{ + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + "net()/", // unknown default addr + //"/dbname?arg=/some/unescaped/path", + } + + for i, tst := range invalidDSNs { + if _, err := ParseDSN(tst); err == nil { + t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) + } + } +} + +func TestDSNReformat(t *testing.T) { + for i, tst := range testDSNs { + dsn1 := tst.in + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Error(err.Error()) + continue + } + cfg1.tls = nil // pointer not static + res1 := fmt.Sprintf("%+v", cfg1) + + dsn2 := cfg1.FormatDSN() + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Error(err.Error()) + continue + } + cfg2.tls = nil // pointer not static + res2 := fmt.Sprintf("%+v", cfg2) + + if res1 != res2 { + t.Errorf("%d. %q does not match %q", i, res2, res1) + } + } +} + +func TestDSNServerPubKey(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + tst := baseDSN + "testKey" + cfg, err := ParseDSN(tst) + if err != nil { + t.Error(err.Error()) + } + + if cfg.ServerPubKey != "testKey" { + t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey) + } + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } + + // Key is missing + tst = baseDSN + "invalid_name" + cfg, err = ParseDSN(tst) + if err == nil { + t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } +} + +func TestDSNServerPubKeyQueryEscape(t *testing.T) { + const name = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name) + + RegisterServerPubKey(name, testPubKeyRSA) + defer DeregisterServerPubKey(name) + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } +} + +func TestDSNWithCustomTLS(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?tls=" + tlsCfg := tls.Config{} + + RegisterTLSConfig("utils_test", &tlsCfg) + defer DeregisterTLSConfig("utils_test") + + // Custom TLS is missing + tst := baseDSN + "invalid_tls" + cfg, err := ParseDSN(tst) + if err == nil { + t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } + + tst = baseDSN + "utils_test" + + // Custom TLS with a server name + name := "foohost" + tlsCfg.ServerName = name + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) + } + + // Custom TLS without a server name + name = "localhost" + tlsCfg.ServerName = "" + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) + } else if tlsCfg.ServerName != "" { + t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) + } +} + +func TestDSNTLSConfig(t *testing.T) { + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true" + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + dsn = "tcp(example.com)/?tls=true" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } +} + +func TestDSNWithCustomTLSQueryEscape(t *testing.T) { + const configKey = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) + name := "foohost" + tlsCfg := tls.Config{ServerName: name} + + RegisterTLSConfig(configKey, &tlsCfg) + defer DeregisterTLSConfig(configKey) + + cfg, err := ParseDSN(dsn) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) + } +} + +func TestDSNUnsafeCollation(t *testing.T) { + _, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") + if err != errInvalidDSNUnsafeCollation { + t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } +} + +func TestParamsAreSorted(t *testing.T) { + expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo" + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.InterpolateParams = true + cfg.Params = map[string]string{ + "quux": "loo", + "foobar": "baz", + } + actual := cfg.FormatDSN() + if actual != expected { + t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual) + } +} + +func TestCloneConfig(t *testing.T) { + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey" + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatal(err.Error()) + } + + cfg2 := cfg.Clone() + if cfg == cfg2 { + t.Errorf("Config.Clone did not create a separate config struct") + } + + if cfg2.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + cfg2.tls.ServerName = "example2.com" + if cfg.tls.ServerName == cfg2.tls.ServerName { + t.Errorf("changed cfg.tls.Server name should not propagate to original Config") + } + + if _, ok := cfg2.Params["foobar"]; !ok { + t.Errorf("cloned Config is missing custom params") + } + + delete(cfg2.Params, "foobar") + + if _, ok := cfg.Params["foobar"]; !ok { + t.Errorf("custom params in cloned Config should not propagate to original Config") + } + + if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) { + t.Errorf("public key in Config should be identical") + } +} + +func TestNormalizeTLSConfig(t *testing.T) { + tt := []struct { + tlsConfig string + want *tls.Config + }{ + {"", nil}, + {"false", nil}, + {"true", &tls.Config{ServerName: "myserver"}}, + {"skip-verify", &tls.Config{InsecureSkipVerify: true}}, + {"preferred", &tls.Config{InsecureSkipVerify: true}}, + {"test_tls_config", &tls.Config{ServerName: "myServerName"}}, + } + + RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"}) + defer func() { DeregisterTLSConfig("test_tls_config") }() + + for _, tc := range tt { + t.Run(tc.tlsConfig, func(t *testing.T) { + cfg := &Config{ + Addr: "myserver:3306", + TLSConfig: tc.tlsConfig, + } + + cfg.normalize() + + if cfg.tls == nil { + if tc.want != nil { + t.Fatal("wanted a tls config but got nil instead") + } + return + } + + if cfg.tls.ServerName != tc.want.ServerName { + t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')", + tc.want.ServerName, cfg.tls.ServerName) + } + if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify { + t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)", + tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify) + } + }) + } +} + +func BenchmarkParseDSN(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tst := range testDSNs { + if _, err := ParseDSN(tst.in); err != nil { + b.Error(err.Error()) + } + } + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/errors.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/errors.go similarity index 61% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/errors.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/errors.go index 1543a80..760782f 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/errors.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/errors.go @@ -9,10 +9,8 @@ package mysql import ( - "database/sql/driver" "errors" "fmt" - "io" "log" "os" ) @@ -22,14 +20,21 @@ var ( ErrInvalidConn = errors.New("invalid connection") ErrMalformPkt = errors.New("malformed packet") ErrNoTLS = errors.New("TLS requested but server does not support TLS") - ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") + ErrNativePassword = errors.New("this user requires mysql native password authentication.") + ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") ErrUnknownPlugin = errors.New("this authentication plugin is not supported") ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") ErrPktSync = errors.New("commands out of sync. You can't run this command now") ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") ErrBusyBuffer = errors.New("busy buffer") + + // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. + // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn + // to trigger a resend. + // See https://github.com/go-sql-driver/mysql/pull/302 + errBadConnNoWrite = errors.New("bad connection") ) var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) @@ -58,74 +63,3 @@ type MySQLError struct { func (me *MySQLError) Error() string { return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } - -// MySQLWarnings is an error type which represents a group of one or more MySQL -// warnings -type MySQLWarnings []MySQLWarning - -func (mws MySQLWarnings) Error() string { - var msg string - for i, warning := range mws { - if i > 0 { - msg += "\r\n" - } - msg += fmt.Sprintf( - "%s %s: %s", - warning.Level, - warning.Code, - warning.Message, - ) - } - return msg -} - -// MySQLWarning is an error type which represents a single MySQL warning. -// Warnings are returned in groups only. See MySQLWarnings -type MySQLWarning struct { - Level string - Code string - Message string -} - -func (mc *mysqlConn) getWarnings() (err error) { - rows, err := mc.Query("SHOW WARNINGS", nil) - if err != nil { - return - } - - var warnings = MySQLWarnings{} - var values = make([]driver.Value, 3) - - for { - err = rows.Next(values) - switch err { - case nil: - warning := MySQLWarning{} - - if raw, ok := values[0].([]byte); ok { - warning.Level = string(raw) - } else { - warning.Level = fmt.Sprintf("%s", values[0]) - } - if raw, ok := values[1].([]byte); ok { - warning.Code = string(raw) - } else { - warning.Code = fmt.Sprintf("%s", values[1]) - } - if raw, ok := values[2].([]byte); ok { - warning.Message = string(raw) - } else { - warning.Message = fmt.Sprintf("%s", values[0]) - } - - warnings = append(warnings, warning) - - case io.EOF: - return warnings - - default: - rows.Close() - return - } - } -} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/errors_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/errors_test.go new file mode 100644 index 0000000..96f9126 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/errors_test.go @@ -0,0 +1,42 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "log" + "testing" +) + +func TestErrorsSetLogger(t *testing.T) { + previous := errLog + defer func() { + errLog = previous + }() + + // set up logger + const expected = "prefix: test\n" + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + logger := log.New(buffer, "prefix: ", 0) + + // print + SetLogger(logger) + errLog.Print("test") + + // check result + if actual := buffer.String(); actual != expected { + t.Errorf("expected %q, got %q", expected, actual) + } +} + +func TestErrorsStrictIgnoreNotes(t *testing.T) { + runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) { + dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") + }) +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/fields.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/fields.go new file mode 100644 index 0000000..e1e2ece --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/fields.go @@ -0,0 +1,194 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "reflect" +) + +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte + charSet uint8 +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/infile.go similarity index 98% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/infile.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/infile.go index 0f975bb..273cb0b 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/infile.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/infile.go @@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { } // send content packets - if err == nil { + // if packetSize == 0, the Reader contains no data + if err == nil && packetSize > 0 { data := make([]byte, 4+packetSize) var n int for err == nil { diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/packets.go similarity index 70% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/packets.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/packets.go index 8d91665..30b3352 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/packets.go @@ -25,26 +25,23 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { - var payload []byte + var prevData []byte for { - // Read packet header + // read packet header data, err := mc.buf.readNext(4) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } - // Packet Length [24 bit] + // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - if pktLen < 1 { - errLog.Print(ErrMalformPkt) - mc.Close() - return nil, driver.ErrBadConn - } - - // Check Packet Sync [8 bit] + // check packet sync [8 bit] if data[3] != mc.sequence { if data[3] > mc.sequence { return nil, ErrPktSyncMul @@ -53,26 +50,41 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } mc.sequence++ - // Read packet body [pktLen bytes] + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)-1 bytes long + if pktLen == 0 { + // there was no previous packet + if prevData == nil { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, ErrInvalidConn + } + + return prevData, nil + } + + // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } - isLastPacket := (pktLen < maxPacketSize) + // return data if this was the last packet + if pktLen < maxPacketSize { + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } - // Zero allocations for non-splitting packets - if isLastPacket && payload == nil { - return data, nil + return append(prevData, data...), nil } - payload = append(payload, data...) - - if isLastPacket { - return payload, nil - } + prevData = append(prevData, data...) } } @@ -80,10 +92,39 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 - if pktLen > mc.maxPacketAllowed { + if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.reset { + mc.reset = false + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + var err error + // If this connection has a ReadTimeout which we've been setting on + // reads, reset it to its default value before we attempt a non-blocking + // read, otherwise the scheduler will just time us out before we can read + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Time{}) + } + if err == nil { + err = connCheck(conn) + } + if err != nil { + errLog.Print("closing bad idle connection: ", err) + mc.Close() + return driver.ErrBadConn + } + } + for { var size int if pktLen >= maxPacketSize { @@ -119,33 +160,47 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) + mc.cleanup() errLog.Print(ErrMalformPkt) } else { + if cerr := mc.canceled.Value(); cerr != nil { + return cerr + } + if n == 0 && pktLen == len(data)-4 { + // only for the first loop iteration when nothing was written yet + return errBadConnNoWrite + } + mc.cleanup() errLog.Print(err) } - return driver.ErrBadConn + return ErrInvalidConn } } /****************************************************************************** -* Initialisation Process * +* Initialization Process * ******************************************************************************/ // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { + data, err = mc.readPacket() if err != nil { - return nil, err + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. + if err == ErrInvalidConn { + return nil, "", driver.ErrBadConn + } + return } if data[0] == iERR { - return nil, mc.handleErrorPacket(data) + return nil, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, fmt.Errorf( + return nil, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -157,7 +212,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] - cipher := data[pos : pos+8] + authData := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -165,10 +220,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { - return nil, ErrOldProtocol + return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, ErrNoTLS + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } } pos += 2 @@ -192,32 +251,32 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // // The official Python library uses the fixed length 12 // which seems to work but technically could have a hidden bug. - cipher = append(cipher, data[pos:pos+12]...) + authData = append(authData, data[pos:pos+12]...) + pos += 13 - // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) // \NUL otherwise - // - //if data[len(data)-1] == 0 { - // return - //} - //return ErrMalformPkt + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + plugin = string(data[pos : pos+end]) + } else { + plugin = string(data[pos:]) + } // make a memory safe copy of the cipher slice var b [20]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -241,10 +300,17 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientMultiStatements } - // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + // encode length of the auth plugin data + var authRespLEIBuf [9]byte + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + clientFlags |= clientPluginAuthLenEncClientData + } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -253,11 +319,11 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // ClientFlags [32 bit] @@ -295,6 +361,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { if err := tlsConn.Handshake(); err != nil { return err } + mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn } @@ -312,9 +379,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[pos] = 0x00 pos++ - // ScrambleBuffer [length encoded integer] - data[pos] = byte(len(scrambleBuff)) - pos += 1 + copy(data[pos+1:], scrambleBuff) + // Auth Data [length encoded integer] + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], authResp) // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -323,52 +390,26 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pos++ } - // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], plugin) data[pos] = 0x00 + pos++ // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(data[:pos]) } -// Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { - // User password - scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { + pktLen := 4 + len(authData) + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } - // Add the scrambled password [null terminated string] - copy(data[4:], scrambleBuff) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Client clear text authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { - // Calculate the packet length and add a tailing 0 - pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn - } - - // Add the clear password [null terminated string] - copy(data[4:], mc.cfg.Passwd) - data[4+pktLen-1] = 0x00 - + // Add the auth data [EOF] + copy(data[4:], authData) return mc.writePacket(data) } @@ -380,11 +421,11 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // Add command byte @@ -399,11 +440,11 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // Add command byte @@ -420,11 +461,11 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // Add command byte @@ -444,37 +485,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { * Result Packets * ******************************************************************************/ +func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { + data, err := mc.readPacket() + if err != nil { + return nil, "", err + } + + // packet indicator + switch data[0] { + + case iOK: + return nil, "", mc.handleOkPacket(data) + + case iAuthMoreData: + return data[1:], "", err + + case iEOF: + if len(data) == 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, "mysql_old_password", nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", ErrMalformPkt + } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", mc.handleErrorPacket(data) + } +} + // Returns error if Packet is not an 'Result OK'-Packet func (mc *mysqlConn) readResultOK() error { data, err := mc.readPacket() - if err == nil { - // packet indicator - switch data[0] { - - case iOK: - return mc.handleOkPacket(data) - - case iEOF: - if len(data) > 1 { - plugin := string(data[1:bytes.IndexByte(data, 0x00)]) - if plugin == "mysql_old_password" { - // using old_passwords - return ErrOldPassword - } else if plugin == "mysql_clear_password" { - // using clear text password - return ErrCleartextPassword - } else { - return ErrUnknownPlugin - } - } else { - return ErrOldPassword - } - - default: // Error otherwise - return mc.handleErrorPacket(data) - } + if err != nil { + return err } - return err + + if data[0] == iOK { + return mc.handleOkPacket(data) + } + return mc.handleErrorPacket(data) } // Result Set Header Packet @@ -517,6 +571,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // Error Number [16 bit uint] errno := binary.LittleEndian.Uint16(data[1:3]) + // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { + // Oops; we are connected to a read-only connection, and won't be able + // to issue any write statements. Since RejectReadOnly is configured, + // we throw away this connection hoping this one would have write + // permission. This is specifically for a possible race condition + // during failover (e.g. on AWS Aurora). See README.md for more. + // + // We explicitly close the connection before returning + // driver.ErrBadConn to ensure that `database/sql` purges this + // connection and initiates a new one for next statement next time. + mc.Close() + return driver.ErrBadConn + } + pos := 3 // SQL State [optional: # + 5bytes string] @@ -551,19 +621,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if err := mc.discardResults(); err != nil { - return err - } - - // warning count [2 bytes] - if !mc.strict { + if mc.status&statusMoreResultsExists != 0 { return nil } - pos := 1 + n + m + 2 - if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - return mc.getWarnings() - } + // warning count [2 bytes] + return nil } @@ -635,14 +698,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } + pos += n // Filler [uint8] + pos++ + // Charset [charset, collation uint8] + columns[i].charSet = data[pos] + pos += 2 + // Length [uint32] - pos += n + 1 + 2 + 4 + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 // Field type [uint8] - columns[i].fieldType = data[pos] + columns[i].fieldType = fieldType(data[pos]) pos++ // Flags [uint16] @@ -665,6 +735,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc + if rows.rs.done { + return io.EOF + } + data, err := mc.readPacket() if err != nil { return err @@ -674,10 +748,10 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil return io.EOF } if data[0] == iERR { @@ -699,7 +773,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { if !mc.parseTime { continue } else { - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( @@ -729,16 +803,19 @@ func (rows *textRows) readRow(dest []driver.Value) error { func (mc *mysqlConn) readUntilEOF() error { for { data, err := mc.readPacket() - - // No Err and no EOF Packet - if err == nil && data[0] != iEOF { - continue - } - if err == nil && data[0] == iEOF && len(data) == 5 { - mc.status = readStatus(data[3:]) + if err != nil { + return err } - return err // Err or EOF + switch data[0] { + case iERR: + return mc.handleErrorPacket(data) + case iEOF: + if len(data) == 5 { + mc.status = readStatus(data[3:]) + } + return nil + } } } @@ -768,14 +845,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // Reserved [8 bit] // Warning count [16 bit uint] - if !stmt.mc.strict { - return columnCount, nil - } - // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { - return columnCount, stmt.mc.getWarnings() - } return columnCount, nil } return 0, err @@ -783,7 +853,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { - maxLen := stmt.mc.maxPacketAllowed - 1 + maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen // After the header (bytes 0-3) follows before the data: @@ -792,7 +862,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // 2 bytes paramID const dataOffset = 1 + 4 + 2 - // Can not use the write buffer since + // Cannot use the write buffer since // a) the buffer is too small // b) it is in use data := make([]byte, 4+1+4+2+len(arg)) @@ -847,20 +917,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc + // Determine threshold dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + // Reset packet-sequence mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. } - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite } // command [1 byte] @@ -885,7 +963,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos := minPktLen var nullMask []byte - if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here @@ -894,10 +972,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. pos += maskLen } else { nullMask = data[pos : pos+maskLen] - for i := 0; i < maskLen; i++ { + for i := range nullMask { nullMask[i] = 0 } pos += maskLen @@ -919,7 +998,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // build NULL-bitmap if arg == nil { nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 continue } @@ -927,7 +1006,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // cache types and values switch v := arg.(type) { case int64: - paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -942,8 +1021,24 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) } + case uint64: + paramTypes[i+i] = byte(fieldTypeLongLong) + paramTypes[i+i+1] = 0x80 // type is unsigned + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + case float64: - paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -959,7 +1054,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case bool: - paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i+1] = 0x00 if v { @@ -971,10 +1066,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case []byte: // Common case (non-nil value) first if v != nil { - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -989,14 +1084,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Handle []byte(nil) as a NULL value nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 case string: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1008,23 +1103,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case time.Time: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - var val []byte + var a [64]byte + var b = a[:0] + if v.IsZero() { - val = []byte("0000-00-00") + b = append(b, "0000-00-00"...) } else { - val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) + b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) } paramValues = appendLengthEncodedInteger(paramValues, - uint64(len(val)), + uint64(len(b)), ) - paramValues = append(paramValues, val...) + paramValues = append(paramValues, b...) default: - return fmt.Errorf("can not convert type: %T", arg) + return fmt.Errorf("cannot convert type: %T", arg) } } @@ -1032,7 +1129,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } } pos += len(paramValues) @@ -1057,8 +1157,6 @@ func (mc *mysqlConn) discardResults() error { if err := mc.readUntilEOF(); err != nil { return err } - } else { - mc.status &^= statusMoreResultsExists } } return nil @@ -1076,16 +1174,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil return io.EOF } + mc := rows.mc rows.mc = nil // Error otherwise - return rows.mc.handleErrorPacket(data) + return mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] @@ -1101,14 +1200,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } // Convert to byte-coded string - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1117,7 +1216,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeShort, fieldTypeYear: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1126,7 +1225,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeInt24, fieldTypeLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1135,7 +1234,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeLongLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1149,7 +1248,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeFloat: - dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 continue @@ -1189,10 +1288,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case isNull: dest[i] = nil continue - case rows.columns[i].fieldType == fieldTypeTime: + case rows.rs.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: @@ -1200,18 +1299,18 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) case rows.mc.parseTime: dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 - if rows.columns[i].fieldType == fieldTypeDate { + if rows.rs.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: @@ -1219,11 +1318,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) } if err == nil { @@ -1235,7 +1334,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) } } diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/packets_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/packets_test.go new file mode 100644 index 0000000..b61e4db --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/packets_test.go @@ -0,0 +1,336 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "errors" + "net" + "testing" + "time" +) + +var ( + errConnClosed = errors.New("connection is closed") + errConnTooManyReads = errors.New("too many reads") + errConnTooManyWrites = errors.New("too many writes") +) + +// struct to mock a net.Conn for testing purposes +type mockConn struct { + laddr net.Addr + raddr net.Addr + data []byte + written []byte + queuedReplies [][]byte + closed bool + read int + reads int + writes int + maxReads int + maxWrites int +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.reads++ + if m.maxReads > 0 && m.reads > m.maxReads { + return 0, errConnTooManyReads + } + + n = copy(b, m.data) + m.read += n + m.data = m.data[n:] + return +} +func (m *mockConn) Write(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.writes++ + if m.maxWrites > 0 && m.writes > m.maxWrites { + return 0, errConnTooManyWrites + } + + n = len(b) + m.written = append(m.written, b...) + + if n > 0 && len(m.queuedReplies) > 0 { + m.data = m.queuedReplies[0] + m.queuedReplies = m.queuedReplies[1:] + } + return +} +func (m *mockConn) Close() error { + m.closed = true + return nil +} +func (m *mockConn) LocalAddr() net.Addr { + return m.laddr +} +func (m *mockConn) RemoteAddr() net.Addr { + return m.raddr +} +func (m *mockConn) SetDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetReadDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// make sure mockConn implements the net.Conn interface +var _ net.Conn = new(mockConn) + +func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: NewConfig(), + netConn: conn, + closech: make(chan struct{}), + maxAllowedPacket: defaultMaxAllowedPacket, + sequence: sequence, + } + return conn, mc +} + +func TestReadPacketSingleByte(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 1 { + t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet)) + } + if packet[0] != 0xff { + t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0]) + } +} + +func TestReadPacketWrongSequenceID(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + // too low sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + mc.sequence = 1 + _, err := mc.readPacket() + if err != ErrPktSync { + t.Errorf("expected ErrPktSync, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // too high sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} + _, err = mc.readPacket() + if err != ErrPktSyncMul { + t.Errorf("expected ErrPktSyncMul, got %v", err) + } +} + +func TestReadPacketSplit(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + data := make([]byte, maxPacketSize*2+4*3) + const pkt2ofs = maxPacketSize + 4 + const pkt3ofs = 2 * (maxPacketSize + 4) + + // case 1: payload has length maxPacketSize + data = data[:pkt2ofs+4] + + // 1st packet has maxPacketSize length and sequence id 0 + // ff ff ff 00 ... + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + + // mark the payload start and end of 1st packet so that we can check if the + // content was correctly appended + data[4] = 0x11 + data[maxPacketSize+3] = 0x22 + + // 2nd packet has payload length 0 and squence id 1 + // 00 00 00 01 + data[pkt2ofs+3] = 0x01 + + conn.data = data + conn.maxReads = 3 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize-1] != 0x22 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1]) + } + + // case 2: payload has length which is a multiple of maxPacketSize + data = data[:cap(data)] + + // 2nd packet now has maxPacketSize length + data[pkt2ofs] = 0xff + data[pkt2ofs+1] = 0xff + data[pkt2ofs+2] = 0xff + + // mark the payload start and end of the 2nd packet + data[pkt2ofs+4] = 0x33 + data[pkt2ofs+maxPacketSize+3] = 0x44 + + // 3rd packet has payload length 0 and squence id 2 + // 00 00 00 02 + data[pkt3ofs+3] = 0x02 + + conn.data = data + conn.reads = 0 + conn.maxReads = 5 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 2*maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[2*maxPacketSize-1] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1]) + } + + // case 3: payload has a length larger maxPacketSize, which is not an exact + // multiple of it + data = data[:pkt2ofs+4+42] + data[pkt2ofs] = 0x2a + data[pkt2ofs+1] = 0x00 + data[pkt2ofs+2] = 0x00 + data[pkt2ofs+4+41] = 0x44 + + conn.data = data + conn.reads = 0 + conn.maxReads = 4 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize+42 { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize+41] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41]) + } +} + +func TestReadPacketFail(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + closech: make(chan struct{}), + } + + // illegal empty (stand-alone) packet + conn.data = []byte{0x00, 0x00, 0x00, 0x00} + conn.maxReads = 1 + _, err := mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read header + conn.closed = true + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.closed = false + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read body + conn.maxReads = 1 + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } +} + +// https://github.com/go-sql-driver/mysql/pull/801 +// not-NUL terminated plugin_name in init packet +func TestRegression801(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: new(Config), + sequence: 42, + closech: make(chan struct{}), + } + + conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, + 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77, + 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, + 112, 97, 115, 115, 119, 111, 114, 100} + conn.maxReads = 1 + + authData, pluginName, err := mc.readHandshakePacket() + if err != nil { + t.Fatalf("got error: %v", err) + } + + if pluginName != "mysql_native_password" { + t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) + } + + expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114, + 47, 85, 75, 109, 99, 51, 77, 50, 64} + if !bytes.Equal(authData, expectedAuthData) { + t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/result.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/result.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/result.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/result.go diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/rows.go new file mode 100644 index 0000000..888bdb5 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/rows.go @@ -0,0 +1,223 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "io" + "math" + "reflect" +) + +type resultSet struct { + columns []mysqlField + columnNames []string + done bool +} + +type mysqlRows struct { + mc *mysqlConn + rs resultSet + finish func() +} + +type binaryRows struct { + mysqlRows +} + +type textRows struct { + mysqlRows +} + +func (rows *mysqlRows) Columns() []string { + if rows.rs.columnNames != nil { + return rows.rs.columnNames + } + + columns := make([]string, len(rows.rs.columns)) + if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { + for i := range columns { + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name + } else { + columns[i] = rows.rs.columns[i].name + } + } + } else { + for i := range columns { + columns[i] = rows.rs.columns[i].name + } + } + + rows.rs.columnNames = columns + return columns +} + +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + return rows.rs.columns[i].typeDatabaseName() +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + +func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + + mc := rows.mc + if mc == nil { + return nil + } + if err := mc.error(); err != nil { + return err + } + + // flip the buffer for this connection if we need to drain it. + // note that for a successful query (i.e. one where rows.next() + // has been called until it returns false), `rows.mc` will be nil + // by the time the user calls `(*Rows).Close`, so we won't reach this + // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 + mc.buf.flip() + + // Remove unread packets from stream + if !rows.rs.done { + err = mc.readUntilEOF() + } + if err == nil { + if err = mc.discardResults(); err != nil { + return err + } + } + + rows.mc = nil + return err +} + +func (rows *mysqlRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *mysqlRows) nextResultSet() (int, error) { + if rows.mc == nil { + return 0, io.EOF + } + if err := rows.mc.error(); err != nil { + return 0, err + } + + // Remove unread packets from stream + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return 0, err + } + rows.rs.done = true + } + + if !rows.HasNextResultSet() { + rows.mc = nil + return 0, io.EOF + } + rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() +} + +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } + + rows.rs.done = true + } +} + +func (rows *binaryRows) NextResultSet() error { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + +func (rows *binaryRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if err := mc.error(); err != nil { + return err + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} + +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + +func (rows *textRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if err := mc.error(); err != nil { + return err + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/statement.go new file mode 100644 index 0000000..f7e3709 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/statement.go @@ -0,0 +1,204 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "fmt" + "io" + "reflect" +) + +type mysqlStmt struct { + mc *mysqlConn + id uint32 + paramCount int +} + +func (stmt *mysqlStmt) Close() error { + if stmt.mc == nil || stmt.mc.closed.IsSet() { + // driver.Stmt.Close can be called more than once, thus this function + // has to be idempotent. + // See also Issue #450 and golang/go#16019. + //errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + stmt.mc = nil + return err +} + +func (stmt *mysqlStmt) NumInput() int { + return stmt.paramCount +} + +func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { + return converter{} +} + +func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } + + mc := stmt.mc + + mc.affectedRows = 0 + mc.insertId = 0 + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err + } + } + + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil +} + +func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { + if stmt.mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } + + mc := stmt.mc + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + rows := new(binaryRows) + + if resLen > 0 { + rows.mc = mc + rows.rs.columns, err = mc.readColumns(resLen) + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } + } + + return rows, err +} + +type converter struct{} + +// ConvertValue mirrors the reference/default converter in database/sql/driver +// with _one_ exception. We support uint64 with their high bit and the default +// implementation does not. This function should be kept in sync with +// database/sql/driver defaultConverter.ConvertValue() except for that +// deliberate difference. +func (c converter) ConvertValue(v interface{}) (driver.Value, error) { + if driver.IsValue(v) { + return v, nil + } + + if vr, ok := v.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Ptr: + // indirect pointers + if rv.IsNil() { + return nil, nil + } else { + return c.ConvertValue(rv.Elem().Interface()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return rv.Uint(), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) + case reflect.String: + return rv.String(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This is an exact copy of the same-named unexported function from the +// database/sql package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/statement_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/statement_test.go new file mode 100644 index 0000000..4b9914f --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/statement_test.go @@ -0,0 +1,126 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "testing" +) + +func TestConvertDerivedString(t *testing.T) { + type derived string + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Derived string type not convertible", err) + } + + if output != "value" { + t.Fatalf("Derived string type not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedByteSlice(t *testing.T) { + type derived []uint8 + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Byte slice not convertible", err) + } + + if bytes.Compare(output.([]byte), []byte("value")) != 0 { + t.Fatalf("Byte slice not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedUnsupportedSlice(t *testing.T) { + type derived []int + + _, err := converter{}.ConvertValue(derived{1}) + if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { + t.Fatal("Unexpected error", err) + } +} + +func TestConvertDerivedBool(t *testing.T) { + type derived bool + + output, err := converter{}.ConvertValue(derived(true)) + if err != nil { + t.Fatal("Derived bool type not convertible", err) + } + + if output != true { + t.Fatalf("Derived bool type not converted, got %#v %T", output, output) + } +} + +func TestConvertPointer(t *testing.T) { + str := "value" + + output, err := converter{}.ConvertValue(&str) + if err != nil { + t.Fatal("Pointer type not convertible", err) + } + + if output != "value" { + t.Fatalf("Pointer type not converted, got %#v %T", output, output) + } +} + +func TestConvertSignedIntegers(t *testing.T) { + values := []interface{}{ + int8(-42), + int16(-42), + int32(-42), + int64(-42), + int(-42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(-42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } +} + +func TestConvertUnsignedIntegers(t *testing.T) { + values := []interface{}{ + uint8(42), + uint16(42), + uint32(42), + uint64(42), + uint(42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != uint64(42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } + + output, err := converter{}.ConvertValue(^uint64(0)) + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if output != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/transaction.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/transaction.go similarity index 88% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/transaction.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/transaction.go index 33c749b..417d727 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/transaction.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/utils.go similarity index 69% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/utils.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/utils.go index d523b7f..cfa10e9 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/utils.go @@ -9,23 +9,32 @@ package mysql import ( - "crypto/sha1" "crypto/tls" + "database/sql" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" + "strconv" "strings" + "sync" + "sync/atomic" "time" ) +// Registry for custom tls.Configs var ( - tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config ) // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // +// Note: The provided tls.Config is exclusively owned by the driver after +// registering it. +// // rootCertPool := x509.NewCertPool() // pem, err := ioutil.ReadFile("/path/ca-cert.pem") // if err != nil { @@ -47,23 +56,36 @@ var ( // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { - if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { return fmt.Errorf("key '%s' is reserved", key) } - if tlsConfigRegister == nil { - tlsConfigRegister = make(map[string]*tls.Config) + tlsConfigLock.Lock() + if tlsConfigRegistry == nil { + tlsConfigRegistry = make(map[string]*tls.Config) } - tlsConfigRegister[key] = config + tlsConfigRegistry[key] = config + tlsConfigLock.Unlock() return nil } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { - if tlsConfigRegister != nil { - delete(tlsConfigRegister, key) + tlsConfigLock.Lock() + if tlsConfigRegistry != nil { + delete(tlsConfigRegistry, key) } + tlsConfigLock.Unlock() +} + +func getTLSConfigClone(key string) (config *tls.Config) { + tlsConfigLock.RLock() + if v, ok := tlsConfigRegistry[key]; ok { + config = v.Clone() + } + tlsConfigLock.RUnlock() + return } // Returns the bool value of the input. @@ -80,119 +102,6 @@ func readBool(input string) (value bool, valid bool) { return } -/****************************************************************************** -* Authentication * -******************************************************************************/ - -// Encrypt password using 4.1+ method -func scramblePassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - // stage1Hash = SHA1(password) - crypt := sha1.New() - crypt.Write(password) - stage1 := crypt.Sum(nil) - - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash - crypt.Reset() - crypt.Write(stage1) - hash := crypt.Sum(nil) - - // outer Hash - crypt.Reset() - crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) - - // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] - } - return scramble -} - -// Encrypt password using pre 4.1 (old password) method -// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c -type myRnd struct { - seed1, seed2 uint32 -} - -const myRndMaxVal = 0x3FFFFFFF - -// Pseudo random number generator -func newMyRnd(seed1, seed2 uint32) *myRnd { - return &myRnd{ - seed1: seed1 % myRndMaxVal, - seed2: seed2 % myRndMaxVal, - } -} - -// Tested to be equivalent to MariaDB's floating point variant -// http://play.golang.org/p/QHvhd4qved -// http://play.golang.org/p/RG0q4ElWDx -func (r *myRnd) NextByte() byte { - r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal - r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal - - return byte(uint64(r.seed1) * 31 / myRndMaxVal) -} - -// Generate binary hash from byte string using insecure pre 4.1 method -func pwHash(password []byte) (result [2]uint32) { - var add uint32 = 7 - var tmp uint32 - - result[0] = 1345345333 - result[1] = 0x12345671 - - for _, c := range password { - // skip spaces and tabs in password - if c == ' ' || c == '\t' { - continue - } - - tmp = uint32(c) - result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) - result[1] += (result[1] << 8) ^ result[0] - add += tmp - } - - // Remove sign bit (1<<31)-1) - result[0] &= 0x7FFFFFFF - result[1] &= 0x7FFFFFFF - - return -} - -// Encrypt password using insecure pre 4.1 method -func scrambleOldPassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - scramble = scramble[:8] - - hashPw := pwHash(password) - hashSc := pwHash(scramble) - - r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) - - var out [8]byte - for i := range out { - out[i] = r.NextByte() + 64 - } - - mask := r.NextByte() - for i := range out { - out[i] ^= mask - } - - return out[:] -} - /****************************************************************************** * Time related utils * ******************************************************************************/ @@ -321,87 +230,104 @@ var zeroDateTime = []byte("0000-00-00 00:00:00.000000") const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" -func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { +func appendMicrosecs(dst, src []byte, decimals int) []byte { + if decimals <= 0 { + return dst + } + if len(src) == 0 { + return append(dst, ".000000"[:decimals+1]...) + } + + microsecs := binary.LittleEndian.Uint32(src[:4]) + p1 := byte(microsecs / 10000) + microsecs -= 10000 * uint32(p1) + p2 := byte(microsecs / 100) + microsecs -= 100 * uint32(p2) + p3 := byte(microsecs) + + switch decimals { + default: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], digits01[p3], + ) + case 1: + return append(dst, '.', + digits10[p1], + ) + case 2: + return append(dst, '.', + digits10[p1], digits01[p1], + ) + case 3: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], + ) + case 4: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + ) + case 5: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], + ) + } +} + +func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { // length expects the deterministic length of the zero value, // negative time and 100+ hours are automatically added if needed if len(src) == 0 { - if justTime { - return zeroDateTime[11 : 11+length], nil - } return zeroDateTime[:length], nil } - var dst []byte // return value - var pt, p1, p2, p3 byte // current digit pair - var zOffs byte // offset of value in zeroDateTime - if justTime { - switch length { - case - 8, // time (can be up to 10 when negative and 100+ hours) - 10, 11, 12, 13, 14, 15: // time with fractional seconds - default: - return nil, fmt.Errorf("illegal TIME length %d", length) + var dst []byte // return value + var p1, p2, p3 byte // current digit pair + + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" } - switch len(src) { - case 8, 12: - default: - return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) - } - // +2 to enable negative time and 100+ hours - dst = make([]byte, 0, length+2) - if src[0] == 1 { - dst = append(dst, '-') - } - if src[1] != 0 { - hour := uint16(src[1])*24 + uint16(src[5]) - pt = byte(hour / 100) - p1 = byte(hour - 100*uint16(pt)) - dst = append(dst, digits01[pt]) - } else { - p1 = src[5] - } - zOffs = 11 - src = src[6:] - } else { - switch length { - case 10, 19, 21, 22, 23, 24, 25, 26: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s length %d", t, length) - } - switch len(src) { - case 4, 7, 11: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) - } - dst = make([]byte, 0, length) - // start with the date - year := binary.LittleEndian.Uint16(src[:2]) - pt = byte(year / 100) - p1 = byte(year - 100*uint16(pt)) - p2, p3 = src[2], src[3] - dst = append(dst, - digits10[pt], digits01[pt], - digits10[p1], digits01[p1], '-', - digits10[p2], digits01[p2], '-', - digits10[p3], digits01[p3], - ) - if length == 10 { - return dst, nil - } - if len(src) == 4 { - return append(dst, zeroDateTime[10:length]...), nil - } - dst = append(dst, ' ') - p1 = src[4] // hour - src = src[5:] + return nil, fmt.Errorf("illegal %s length %d", t, length) } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt := year / 100 + p1 = byte(year - 100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + // p1 is 2-digit hour, src is after hour p2, p3 = src[0], src[1] dst = append(dst, @@ -409,51 +335,49 @@ func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value digits10[p2], digits01[p2], ':', digits10[p3], digits01[p3], ) - if length <= byte(len(dst)) { - return dst, nil - } - src = src[2:] + return appendMicrosecs(dst, src[2:], int(length)-20), nil +} + +func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed if len(src) == 0 { - return append(dst, zeroDateTime[19:zOffs+length]...), nil + return zeroDateTime[11 : 11+length], nil } - microsecs := binary.LittleEndian.Uint32(src[:4]) - p1 = byte(microsecs / 10000) - microsecs -= 10000 * uint32(p1) - p2 = byte(microsecs / 100) - microsecs -= 100 * uint32(p2) - p3 = byte(microsecs) - switch decimals := zOffs + length - 20; decimals { + var dst []byte // return value + + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds default: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - digits10[p3], digits01[p3], - ), nil - case 1: - return append(dst, '.', - digits10[p1], - ), nil - case 2: - return append(dst, '.', - digits10[p1], digits01[p1], - ), nil - case 3: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], - ), nil - case 4: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - ), nil - case 5: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - digits10[p3], - ), nil + return nil, fmt.Errorf("illegal TIME length %d", length) } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) + } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + days := binary.LittleEndian.Uint32(src[1:5]) + hours := int64(days)*24 + int64(src[5]) + + if hours >= 100 { + dst = strconv.AppendInt(dst, hours, 10) + } else { + dst = append(dst, digits10[hours], digits01[hours]) + } + + min, sec := src[6], src[7] + dst = append(dst, ':', + digits10[min], digits01[min], ':', + digits10[sec], digits01[sec], + ) + return appendMicrosecs(dst, src[8:], int(length)-9), nil } /****************************************************************************** @@ -519,7 +443,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { // Check data length if len(b) >= n { - return b[n-int(num) : n], false, n, nil + return b[n-int(num) : n : n], false, n, nil } return nil, false, n, io.EOF } @@ -548,8 +472,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - switch b[0] { + switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 @@ -738,3 +662,94 @@ func escapeStringQuotes(buf []byte, v string) []byte { return buf[:pos] } + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} + +// atomicBool is a wrapper around uint32 for usage as a boolean value with +// atomic access. +type atomicBool struct { + _noCopy noCopy + value uint32 +} + +// IsSet returns whether the current boolean value is true +func (ab *atomicBool) IsSet() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Set sets the value of the bool regardless of the previous value +func (ab *atomicBool) Set(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// TrySet sets the value of the bool and returns whether the value changed +func (ab *atomicBool) TrySet(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) == 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} + +// atomicError is a wrapper for atomically accessed error values +type atomicError struct { + _noCopy noCopy + value atomic.Value +} + +// Set sets the error value regardless of the previous value. +// The value must not be nil +func (ae *atomicError) Set(value error) { + ae.value.Store(value) +} + +// Value returns the current error value +func (ae *atomicError) Value() error { + if v := ae.value.Load(); v != nil { + // this will panic if the value doesn't implement the error interface + return v.(error) + } + return nil +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/utils_test.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/utils_test.go new file mode 100644 index 0000000..8951a7a --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/go-sql-driver/mysql/utils_test.go @@ -0,0 +1,334 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/binary" + "testing" + "time" +) + +func TestScanNullTime(t *testing.T) { + var scanTests = []struct { + in interface{} + error bool + valid bool + time time.Time + }{ + {tDate, false, true, tDate}, + {sDate, false, true, tDate}, + {[]byte(sDate), false, true, tDate}, + {tDateTime, false, true, tDateTime}, + {sDateTime, false, true, tDateTime}, + {[]byte(sDateTime), false, true, tDateTime}, + {tDate0, false, true, tDate0}, + {sDate0, false, true, tDate0}, + {[]byte(sDate0), false, true, tDate0}, + {sDateTime0, false, true, tDate0}, + {[]byte(sDateTime0), false, true, tDate0}, + {"", true, false, tDate0}, + {"1234", true, false, tDate0}, + {0, true, false, tDate0}, + } + + var nt = NullTime{} + var err error + + for _, tst := range scanTests { + err = nt.Scan(tst.in) + if (err != nil) != tst.error { + t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil)) + } + if nt.Valid != tst.valid { + t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid) + } + if nt.Time != tst.time { + t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time) + } + } +} + +func TestLengthEncodedInteger(t *testing.T) { + var integerTests = []struct { + num uint64 + encoded []byte + }{ + {0x0000000000000000, []byte{0x00}}, + {0x0000000000000012, []byte{0x12}}, + {0x00000000000000fa, []byte{0xfa}}, + {0x0000000000000100, []byte{0xfc, 0x00, 0x01}}, + {0x0000000000001234, []byte{0xfc, 0x34, 0x12}}, + {0x000000000000ffff, []byte{0xfc, 0xff, 0xff}}, + {0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}}, + {0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}}, + {0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}}, + {0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}}, + {0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}}, + {0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + } + + for _, tst := range integerTests { + num, isNull, numLen := readLengthEncodedInteger(tst.encoded) + if isNull { + t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num) + } + if num != tst.num { + t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num) + } + if numLen != len(tst.encoded) { + t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen) + } + encoded := appendLengthEncodedInteger(nil, num) + if !bytes.Equal(encoded, tst.encoded) { + t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded) + } + } +} + +func TestFormatBinaryDateTime(t *testing.T) { + rawDate := [11]byte{} + binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years + rawDate[2] = 12 // months + rawDate[3] = 30 // days + rawDate[4] = 15 // hours + rawDate[5] = 46 // minutes + rawDate[6] = 23 // seconds + binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds + expect := func(expected string, inlen, outlen uint8) { + actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for length in %d, out %d", + expected, actual, inlen, outlen, + ) + } + } + expect("0000-00-00", 0, 10) + expect("0000-00-00 00:00:00", 0, 19) + expect("1978-12-30", 4, 10) + expect("1978-12-30 15:46:23", 7, 19) + expect("1978-12-30 15:46:23.987654", 11, 26) +} + +func TestFormatBinaryTime(t *testing.T) { + expect := func(expected string, src []byte, outlen uint8) { + actual, _ := formatBinaryTime(src, outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for src=%q and outlen=%d", + expected, actual, src, outlen) + } + } + + // binary format: + // sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4) + + // Zeros + expect("00:00:00", []byte{}, 8) + expect("00:00:00.0", []byte{}, 10) + expect("00:00:00.000000", []byte{}, 15) + + // Without micro(4) + expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11) + expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8) + expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8) + expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8) + + // With micro(4) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11) + expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15) +} + +func TestEscapeBackslash(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesBackslash([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringBackslash([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\\0bar", "foo\x00bar") + expect("foo\\nbar", "foo\nbar") + expect("foo\\rbar", "foo\rbar") + expect("foo\\Zbar", "foo\x1abar") + expect("foo\\\"bar", "foo\"bar") + expect("foo\\\\bar", "foo\\bar") + expect("foo\\'bar", "foo'bar") +} + +func TestEscapeQuotes(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesQuotes([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringQuotes([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\x00bar", "foo\x00bar") // not affected + expect("foo\nbar", "foo\nbar") // not affected + expect("foo\rbar", "foo\rbar") // not affected + expect("foo\x1abar", "foo\x1abar") // not affected + expect("foo''bar", "foo'bar") // affected + expect("foo\"bar", "foo\"bar") // not affected +} + +func TestAtomicBool(t *testing.T) { + var ab atomicBool + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(true) + if ab.value != 1 { + t.Fatal("Set(true) did not set value to 1") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(false) + if ab.value != 0 { + t.Fatal("Set(false) did not set value to 0") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(false) + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + if ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to fail") + } + if !ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to succeed") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + if ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to fail") + } + if !ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to succeed") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯ +} + +func TestAtomicError(t *testing.T) { + var ae atomicError + if ae.Value() != nil { + t.Fatal("Expected value to be nil") + } + + ae.Set(ErrMalformPkt) + if v := ae.Value(); v != ErrMalformPkt { + if v == nil { + t.Fatal("Value is still nil") + } + t.Fatal("Error did not match") + } + ae.Set(ErrPktSync) + if ae.Value() == ErrMalformPkt { + t.Fatal("Error still matches old error") + } + if v := ae.Value(); v != ErrPktSync { + t.Fatal("Error did not match") + } +} + +func TestIsolationLevelMapping(t *testing.T) { + data := []struct { + level driver.IsolationLevel + expected string + }{ + { + level: driver.IsolationLevel(sql.LevelReadCommitted), + expected: "READ COMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelRepeatableRead), + expected: "REPEATABLE READ", + }, + { + level: driver.IsolationLevel(sql.LevelReadUncommitted), + expected: "READ UNCOMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelSerializable), + expected: "SERIALIZABLE", + }, + } + + for i, td := range data { + if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil { + t.Fatal(i, td.expected, actual, err) + } + } + + // check unsupported mapping + expectedErr := "mysql: unsupported isolation level: 7" + actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable)) + if actual != "" || err == nil { + t.Fatal("Expected error on unsupported isolation level") + } + if err.Error() != expectedErr { + t.Fatalf("Expected error to be %q, got %q", expectedErr, err) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/LICENSE similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/LICENSE rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/LICENSE diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/bind.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/bind.go similarity index 76% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/bind.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/bind.go index 564635c..0fdc443 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/bind.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/bind.go @@ -21,13 +21,13 @@ const ( // BindType returns the bindtype for a given database given a drivername. func BindType(driverName string) int { switch driverName { - case "postgres", "pgx": + case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres": return DOLLAR case "mysql": return QUESTION case "sqlite3": return QUESTION - case "oci8": + case "oci8", "ora", "goracle": return NAMED } return UNKNOWN @@ -43,27 +43,28 @@ func Rebind(bindType int, query string) string { return query } - qb := []byte(query) // Add space enough for 10 params before we have to allocate - rqb := make([]byte, 0, len(qb)+10) - j := 1 - for _, b := range qb { - if b == '?' { - switch bindType { - case DOLLAR: - rqb = append(rqb, '$') - case NAMED: - rqb = append(rqb, ':', 'a', 'r', 'g') - } - for _, b := range strconv.Itoa(j) { - rqb = append(rqb, byte(b)) - } - j++ - } else { - rqb = append(rqb, b) + rqb := make([]byte, 0, len(query)+10) + + var i, j int + + for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") { + rqb = append(rqb, query[:i]...) + + switch bindType { + case DOLLAR: + rqb = append(rqb, '$') + case NAMED: + rqb = append(rqb, ':', 'a', 'r', 'g') } + + j++ + rqb = strconv.AppendInt(rqb, int64(j), 10) + + query = query[i+1:] } - return string(rqb) + + return string(append(rqb, query...)) } // Experimental implementation of Rebind which uses a bytes.Buffer. The code is @@ -112,7 +113,8 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { v := reflect.ValueOf(arg) t := reflectx.Deref(v.Type()) - if t.Kind() == reflect.Slice { + // []byte is a driver.Value type so it should not be expanded + if t.Kind() == reflect.Slice && t != reflect.TypeOf([]byte{}) { meta[i].length = v.Len() meta[i].v = v @@ -135,9 +137,9 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { } newArgs := make([]interface{}, 0, flatArgsCount) + buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount)) var arg, offset int - var buf bytes.Buffer for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { if arg >= len(meta) { @@ -163,13 +165,12 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { // write everything up to and including our ? character buf.WriteString(query[:offset+i+1]) - newArgs = append(newArgs, argMeta.v.Index(0).Interface()) - for si := 1; si < argMeta.length; si++ { buf.WriteString(", ?") - newArgs = append(newArgs, argMeta.v.Index(si).Interface()) } + newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) + // slice the query and reset the offset. this avoids some bookkeeping for // the write after the loop query = query[offset+i+1:] @@ -184,3 +185,24 @@ func In(query string, args ...interface{}) (string, []interface{}, error) { return buf.String(), newArgs, nil } + +func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { + switch val := v.Interface().(type) { + case []interface{}: + args = append(args, val...) + case []int: + for i := range val { + args = append(args, val[i]) + } + case []string: + for i := range val { + args = append(args, val[i]) + } + default: + for si := 0; si < vlen; si++ { + args = append(args, v.Index(si).Interface()) + } + } + + return args +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/doc.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/doc.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/doc.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/doc.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/named.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/named.go similarity index 92% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/named.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/named.go index 4df8095..69eb954 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/named.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/named.go @@ -36,6 +36,7 @@ func (n *NamedStmt) Close() error { } // Exec executes a named statement using the struct passed. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { @@ -45,6 +46,7 @@ func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { } // Query executes a named statement using the struct argument, returning rows. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { @@ -56,6 +58,7 @@ func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { // QueryRow executes a named statement against the database. Because sqlx cannot // create a *sql.Row with an error condition pre-set for binding errors, sqlx // returns a *sqlx.Row instead. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRow(arg interface{}) *Row { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { @@ -65,6 +68,7 @@ func (n *NamedStmt) QueryRow(arg interface{}) *Row { } // MustExec execs a NamedStmt, panicing on error +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) MustExec(arg interface{}) sql.Result { res, err := n.Exec(arg) if err != nil { @@ -74,6 +78,7 @@ func (n *NamedStmt) MustExec(arg interface{}) sql.Result { } // Queryx using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { r, err := n.Query(arg) if err != nil { @@ -84,11 +89,13 @@ func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { // QueryRowx this NamedStmt. Because of limitations with QueryRow, this is // an alias for QueryRow. +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowx(arg interface{}) *Row { return n.QueryRow(arg) } // Select using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { rows, err := n.Queryx(arg) if err != nil { @@ -100,6 +107,7 @@ func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { } // Get using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { r := n.QueryRowx(arg) return r.scanAny(dest, false) @@ -155,16 +163,18 @@ func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{ v = v.Elem() } - fields := m.TraversalsByName(v.Type(), names) - for i, t := range fields { + err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error { if len(t) == 0 { - return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg) + return fmt.Errorf("could not find name %s in %#v", names[i], arg) } + val := reflectx.FieldByIndexesReadOnly(v, t) arglist = append(arglist, val.Interface()) - } - return arglist, nil + return nil + }) + + return arglist, err } // like bindArgs, but for maps. @@ -250,7 +260,7 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e inName = true name = []byte{} // if we're in a name, and this is an allowed character, continue - } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_') && i != last { + } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { // append the byte to the name if we are in a name and not on the last byte name = append(name, b) // if we're in a name and it's not an allowed character, the name is done diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/named_context.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/named_context.go new file mode 100644 index 0000000..9405007 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/named_context.go @@ -0,0 +1,132 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" +) + +// A union interface of contextPreparer and binder, required to be able to +// prepare named statements with context (as the bindtype must be determined). +type namedPreparerContext interface { + PreparerContext + binder +} + +func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { + bindType := BindType(p.DriverName()) + q, args, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return nil, err + } + stmt, err := PreparexContext(ctx, p, q) + if err != nil { + return nil, err + } + return &NamedStmt{ + QueryString: q, + Params: args, + Stmt: stmt, + }, nil +} + +// ExecContext executes a named statement using the struct passed. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return *new(sql.Result), err + } + return n.Stmt.ExecContext(ctx, args...) +} + +// QueryContext executes a named statement using the struct argument, returning rows. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return nil, err + } + return n.Stmt.QueryContext(ctx, args...) +} + +// QueryRowContext executes a named statement against the database. Because sqlx cannot +// create a *sql.Row with an error condition pre-set for binding errors, sqlx +// returns a *sqlx.Row instead. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return &Row{err: err} + } + return n.Stmt.QueryRowxContext(ctx, args...) +} + +// MustExecContext execs a NamedStmt, panicing on error +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { + res, err := n.ExecContext(ctx, arg) + if err != nil { + panic(err) + } + return res +} + +// QueryxContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { + r, err := n.QueryContext(ctx, arg) + if err != nil { + return nil, err + } + return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err +} + +// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is +// an alias for QueryRow. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { + return n.QueryRowContext(ctx, arg) +} + +// SelectContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { + rows, err := n.QueryxContext(ctx, arg) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// GetContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { + r := n.QueryRowxContext(ctx, arg) + return r.scanAny(dest, false) +} + +// NamedQueryContext binds a named query and then runs Query on the result using the +// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with +// map[string]interface{} types. +func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.QueryxContext(ctx, q, args...) +} + +// NamedExecContext uses BindStruct to get a query executable by the driver and +// then runs Exec on the result. Returns an error from the binding +// or the query excution itself. +func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.ExecContext(ctx, q, args...) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go similarity index 65% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go index 5728011..73c21eb 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go @@ -1,20 +1,19 @@ // Package reflectx implements extensions to the standard reflect lib suitable -// for implementing marshaling and unmarshaling packages. The main Mapper type -// allows for Go-compatible named atribute access, including accessing embedded +// for implementing marshalling and unmarshalling packages. The main Mapper type +// allows for Go-compatible named attribute access, including accessing embedded // struct attributes and the ability to use functions and struct tags to // customize field names. // package reflectx import ( - "fmt" "reflect" "runtime" "strings" "sync" ) -// A FieldInfo is a collection of metadata about a struct field. +// A FieldInfo is metadata for a struct field. type FieldInfo struct { Index []int Path string @@ -41,7 +40,8 @@ func (f StructMap) GetByPath(path string) *FieldInfo { } // GetByTraversal returns a *FieldInfo for a given integer path. It is -// analagous to reflect.FieldByIndex. +// analogous to reflect.FieldByIndex, but using the cached traversal +// rather than re-executing the reflect machinery each time. func (f StructMap) GetByTraversal(index []int) *FieldInfo { if len(index) == 0 { return nil @@ -58,8 +58,8 @@ func (f StructMap) GetByTraversal(index []int) *FieldInfo { } // Mapper is a general purpose mapper of names to struct fields. A Mapper -// behaves like most marshallers, optionally obeying a field tag for name -// mapping and a function to provide a basic mapping of fields to names. +// behaves like most marshallers in the standard library, obeying a field tag +// for name mapping but also providing a basic transform function. type Mapper struct { cache map[reflect.Type]*StructMap tagName string @@ -68,8 +68,8 @@ type Mapper struct { mutex sync.Mutex } -// NewMapper returns a new mapper which optionally obeys the field tag given -// by tagName. If tagName is the empty string, it is ignored. +// NewMapper returns a new mapper using the tagName as its struct field tag. +// If tagName is the empty string, it is ignored. func NewMapper(tagName string) *Mapper { return &Mapper{ cache: make(map[reflect.Type]*StructMap), @@ -127,7 +127,7 @@ func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { return r } -// FieldByName returns a field by the its mapped name as a reflect.Value. +// FieldByName returns a field by its mapped name as a reflect.Value. // Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. // Returns zero Value if the name is not found. func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { @@ -166,27 +166,47 @@ func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { // traversals for each mapped name. Panics if t is not a struct or Indirectable // to a struct. Returns empty int slice for each name not found. func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { - t = Deref(t) - mustBe(t, reflect.Struct) - tm := m.TypeMap(t) - r := make([][]int, 0, len(names)) - for _, name := range names { - fi, ok := tm.Names[name] - if !ok { + m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { + if i == nil { r = append(r, []int{}) } else { - r = append(r, fi.Index) + r = append(r, i) } - } + + return nil + }) return r } -// FieldByIndexes returns a value for a particular struct traversal. +// TraversalsByNameFunc traverses the mapped names and calls fn with the index of +// each name and the struct traversal represented by that name. Panics if t is not +// a struct or Indirectable to a struct. Returns the first error returned by fn or nil. +func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error { + t = Deref(t) + mustBe(t, reflect.Struct) + tm := m.TypeMap(t) + for i, name := range names { + fi, ok := tm.Names[name] + if !ok { + if err := fn(i, nil); err != nil { + return err + } + } else { + if err := fn(i, fi.Index); err != nil { + return err + } + } + } + return nil +} + +// FieldByIndexes returns a value for the field given by the struct traversal +// for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { for _, i := range indexes { v = reflect.Indirect(v).Field(i) - // if this is a pointer, it's possible it is nil + // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { alloc := reflect.New(Deref(v.Type())) v.Set(alloc) @@ -225,13 +245,12 @@ type kinder interface { // mustBe checks a value against a kind, panicing with a reflect.ValueError // if the kind isn't that which is required. func mustBe(v kinder, expected reflect.Kind) { - k := v.Kind() - if k != expected { + if k := v.Kind(); k != expected { panic(&reflect.ValueError{Method: methodName(), Kind: k}) } } -// methodName is returns the caller of the function calling methodName +// methodName returns the caller of the function calling methodName func methodName() string { pc, _, _, _ := runtime.Caller(2) f := runtime.FuncForPC(pc) @@ -257,19 +276,92 @@ func apnd(is []int, i int) []int { return x } +type mapf func(string) string + +// parseName parses the tag and the target name for the given field using +// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the +// field's name to a target name, and tagMapFunc for mapping the tag to +// a target name. +func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) { + // first, set the fieldName to the field's name + fieldName = field.Name + // if a mapFunc is set, use that to override the fieldName + if mapFunc != nil { + fieldName = mapFunc(fieldName) + } + + // if there's no tag to look for, return the field name + if tagName == "" { + return "", fieldName + } + + // if this tag is not set using the normal convention in the tag, + // then return the fieldname.. this check is done because according + // to the reflect documentation: + // If the tag does not have the conventional format, + // the value returned by Get is unspecified. + // which doesn't sound great. + if !strings.Contains(string(field.Tag), tagName+":") { + return "", fieldName + } + + // at this point we're fairly sure that we have a tag, so lets pull it out + tag = field.Tag.Get(tagName) + + // if we have a mapper function, call it on the whole tag + // XXX: this is a change from the old version, which pulled out the name + // before the tagMapFunc could be run, but I think this is the right way + if tagMapFunc != nil { + tag = tagMapFunc(tag) + } + + // finally, split the options from the name + parts := strings.Split(tag, ",") + fieldName = parts[0] + + return tag, fieldName +} + +// parseOptions parses options out of a tag string, skipping the name +func parseOptions(tag string) map[string]string { + parts := strings.Split(tag, ",") + options := make(map[string]string, len(parts)) + if len(parts) > 1 { + for _, opt := range parts[1:] { + // short circuit potentially expensive split op + if strings.Contains(opt, "=") { + kv := strings.Split(opt, "=") + options[kv[0]] = kv[1] + continue + } + options[opt] = "" + } + } + return options +} + // getMapping returns a mapping for the t type, using the tagName, mapFunc and // tagMapFunc to determine the canonical names of fields. -func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap { +func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap { m := []*FieldInfo{} root := &FieldInfo{} queue := []typeQueue{} queue = append(queue, typeQueue{Deref(t), root, ""}) +QueueLoop: for len(queue) != 0 { // pop the first item off of the queue tq := queue[0] queue = queue[1:] + + // ignore recursive field + for p := tq.fi.Parent; p != nil; p = p.Parent { + if tq.fi.Field.Type == p.Field.Type { + continue QueueLoop + } + } + nChildren := 0 if tq.t.Kind() == reflect.Struct { nChildren = tq.t.NumField() @@ -278,55 +370,33 @@ func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) // iterate through all of its fields for fieldPos := 0; fieldPos < nChildren; fieldPos++ { + f := tq.t.Field(fieldPos) - fi := FieldInfo{} - fi.Field = f - fi.Zero = reflect.New(f.Type).Elem() - fi.Options = map[string]string{} - - var tag, name string - if tagName != "" && strings.Contains(string(f.Tag), tagName+":") { - tag = f.Tag.Get(tagName) - name = tag - } else { - if mapFunc != nil { - name = mapFunc(f.Name) - } - } - - parts := strings.Split(name, ",") - if len(parts) > 1 { - name = parts[0] - for _, opt := range parts[1:] { - kv := strings.Split(opt, "=") - if len(kv) > 1 { - fi.Options[kv[0]] = kv[1] - } else { - fi.Options[kv[0]] = "" - } - } - } - - if tagMapFunc != nil { - tag = tagMapFunc(tag) - } - - fi.Name = name - - if tq.pp == "" || (tq.pp == "" && tag == "") { - fi.Path = fi.Name - } else { - fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name) - } + // parse the tag and the target name using the mapping options for this field + tag, name := parseName(f, tagName, mapFunc, tagMapFunc) // if the name is "-", disabled via a tag, skip it if name == "-" { continue } + fi := FieldInfo{ + Field: f, + Name: name, + Zero: reflect.New(f.Type).Elem(), + Options: parseOptions(tag), + } + + // if the path is empty this path is just the name + if tq.pp == "" { + fi.Path = fi.Name + } else { + fi.Path = tq.pp + "." + fi.Name + } + // skip unexported fields - if len(f.PkgPath) != 0 { + if len(f.PkgPath) != 0 && !f.Anonymous { continue } diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/sqlx.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/sqlx.go similarity index 91% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/sqlx.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/sqlx.go index b1ba4cf..4385c3f 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/sqlx.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/sqlx.go @@ -10,6 +10,7 @@ import ( "path/filepath" "reflect" "strings" + "sync" "github.com/jmoiron/sqlx/reflectx" ) @@ -17,7 +18,7 @@ import ( // Although the NameMapper is convenient, in practice it should not // be relied on except for application code. If you are writing a library // that uses sqlx, you should be aware that the name mappings you expect -// can be overridded by your user's application. +// can be overridden by your user's application. // NameMapper is used to map column names to struct field names. By default, // it uses strings.ToLower to lowercase struct field names. It can be set @@ -30,8 +31,14 @@ var origMapper = reflect.ValueOf(NameMapper) // importers have time to customize the NameMapper. var mpr *reflectx.Mapper +// mprMu protects mpr. +var mprMu sync.Mutex + // mapper returns a valid mapper using the configured NameMapper func. func mapper() *reflectx.Mapper { + mprMu.Lock() + defer mprMu.Unlock() + if mpr == nil { mpr = reflectx.NewMapperFunc("db", NameMapper) } else if origMapper != reflect.ValueOf(NameMapper) { @@ -221,6 +228,14 @@ func (r *Row) Columns() ([]string, error) { return r.rows.Columns() } +// ColumnTypes returns the underlying sql.Rows.ColumnTypes(), or the deferred error +func (r *Row) ColumnTypes() ([]*sql.ColumnType, error) { + if r.err != nil { + return []*sql.ColumnType{}, r.err + } + return r.rows.ColumnTypes() +} + // Err returns the error encountered while scanning. func (r *Row) Err() error { return r.err @@ -289,21 +304,26 @@ func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, e } // NamedQuery using this DB. +// Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(db, query, arg) } // NamedExec using this DB. +// Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(db, query, arg) } // Select using this DB. +// Any placeholder parameters are replaced with supplied args. func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { return Select(db, dest, query, args...) } // Get using this DB. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { return Get(db, dest, query, args...) } @@ -328,6 +348,7 @@ func (db *DB) Beginx() (*Tx, error) { } // Queryx queries the database and returns an *sqlx.Rows. +// Any placeholder parameters are replaced with supplied args. func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := db.DB.Query(query, args...) if err != nil { @@ -337,12 +358,14 @@ func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { } // QueryRowx queries the database and returns an *sqlx.Row. +// Any placeholder parameters are replaced with supplied args. func (db *DB) QueryRowx(query string, args ...interface{}) *Row { rows, err := db.DB.Query(query, args...) return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } // MustExec (panic) runs MustExec using this database. +// Any placeholder parameters are replaced with supplied args. func (db *DB) MustExec(query string, args ...interface{}) sql.Result { return MustExec(db, query, args...) } @@ -387,21 +410,25 @@ func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, e } // NamedQuery within a transaction. +// Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(tx, query, arg) } // NamedExec a named query within a transaction. +// Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(tx, query, arg) } // Select within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { return Select(tx, dest, query, args...) } // Queryx within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := tx.Tx.Query(query, args...) if err != nil { @@ -411,17 +438,21 @@ func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { } // QueryRowx within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { rows, err := tx.Tx.Query(query, args...) return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} } // Get within a transaction. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { return Get(tx, dest, query, args...) } // MustExec runs MustExec within a transaction. +// Any placeholder parameters are replaced with supplied args. func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { return MustExec(tx, query, args...) } @@ -478,28 +509,34 @@ func (s *Stmt) Unsafe() *Stmt { } // Select using the prepared statement. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) Select(dest interface{}, args ...interface{}) error { return Select(&qStmt{s}, dest, "", args...) } // Get using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func (s *Stmt) Get(dest interface{}, args ...interface{}) error { return Get(&qStmt{s}, dest, "", args...) } // MustExec (panic) using this statement. Note that the query portion of the error // output will be blank, as Stmt does not expose its query. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExec(args ...interface{}) sql.Result { return MustExec(&qStmt{s}, "", args...) } // QueryRowx using this statement. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowx(args ...interface{}) *Row { qs := &qStmt{s} return qs.QueryRowx("", args...) } // Queryx using this statement. +// Any placeholder parameters are replaced with supplied args. func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { qs := &qStmt{s} return qs.Queryx("", args...) @@ -564,7 +601,7 @@ func (r *Rows) StructScan(dest interface{}) error { return errors.New("must pass a pointer, not a value, to StructScan destination") } - v = reflect.Indirect(v) + v = v.Elem() if !r.started { columns, err := r.Columns() @@ -576,7 +613,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.fields = m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(r.fields); err != nil && !r.unsafe { - return fmt.Errorf("missing destination name %s", columns[f]) + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } r.values = make([]interface{}, len(columns)) r.started = true @@ -598,10 +635,14 @@ func (r *Rows) StructScan(dest interface{}) error { func Connect(driverName, dataSourceName string) (*DB, error) { db, err := Open(driverName, dataSourceName) if err != nil { - return db, err + return nil, err } err = db.Ping() - return db, err + if err != nil { + db.Close() + return nil, err + } + return db, nil } // MustConnect connects to a database and panics on error. @@ -626,6 +667,7 @@ func Preparex(p Preparer, query string) (*Stmt, error) { // into dest, which must be a slice. If the slice elements are scannable, then // the result set must have only one column. Otherwise, StructScan is used. // The *sql.Rows are closed automatically. +// Any placeholder parameters are replaced with supplied args. func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { rows, err := q.Queryx(query, args...) if err != nil { @@ -639,6 +681,8 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro // Get does a QueryRow using the provided Queryer, and scans the resulting row // to dest. If dest is scannable, the result must only have one column. Otherwise, // StructScan is used. Get will return sql.ErrNoRows like row.Scan would. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { r := q.QueryRowx(query, args...) return r.scanAny(dest, false) @@ -669,6 +713,7 @@ func LoadFile(e Execer, path string) (*sql.Result, error) { } // MustExec execs the query using e and panics if there was an error. +// Any placeholder parameters are replaced with supplied args. func MustExec(e Execer, query string, args ...interface{}) sql.Result { res, err := e.Exec(query, args...) if err != nil { @@ -691,6 +736,10 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { if r.err != nil { return r.err } + if r.rows == nil { + r.err = sql.ErrNoRows + return r.err + } defer r.rows.Close() v := reflect.ValueOf(dest) @@ -726,7 +775,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { fields := m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(fields); err != nil && !r.unsafe { - return fmt.Errorf("missing destination name %s", columns[f]) + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values := make([]interface{}, len(columns)) @@ -744,7 +793,7 @@ func (r *Row) StructScan(dest interface{}) error { } // SliceScan a row, returning a []interface{} with values similar to MapScan. -// This function is primarly intended for use where the number of columns +// This function is primarily intended for use where the number of columns // is not known. Because you can pass an []interface{} directly to Scan, // it's recommended that you do that as it will not have to allocate new // slices per row. @@ -779,7 +828,7 @@ func SliceScan(r ColScanner) ([]interface{}, error) { // executes SQL from input). Please do not use this as a primary interface! // This will modify the map sent to it in place, so reuse the same map with // care. Columns which occur more than once in the result will overwrite -// eachother! +// each other! func MapScan(r ColScanner, dest map[string]interface{}) error { // ignore r.started, since we needn't use reflect for anything. columns, err := r.Columns() @@ -892,7 +941,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { fields := m.TraversalsByName(base, columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(fields); err != nil && !isUnsafe(rows) { - return fmt.Errorf("missing destination name %s", columns[f]) + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) @@ -902,6 +951,9 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { v = reflect.Indirect(vp) err = fieldsByTraversal(v, fields, values, true) + if err != nil { + return err + } // scan into the struct field pointers and append to our results err = rows.Scan(values...) @@ -919,6 +971,9 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { for rows.Next() { vp = reflect.New(base) err = rows.Scan(vp.Interface()) + if err != nil { + return err + } // append if isPtr { direct.Set(reflect.Append(direct, vp)) @@ -937,7 +992,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { // anyway) works on a rows object. // StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. -// StructScan will scan in the entire rows result, so if you need do not want to +// StructScan will scan in the entire rows result, so if you do not want to // allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. // If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. func StructScan(rows rowsi, dest interface{}) error { diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/sqlx_context.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/sqlx_context.go new file mode 100644 index 0000000..d58ff33 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/jmoiron/sqlx/sqlx_context.go @@ -0,0 +1,348 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" + "fmt" + "io/ioutil" + "path/filepath" + "reflect" +) + +// ConnectContext to a database and verify with a ping. +func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { + db, err := Open(driverName, dataSourceName) + if err != nil { + return db, err + } + err = db.PingContext(ctx) + return db, err +} + +// QueryerContext is an interface used by GetContext and SelectContext +type QueryerContext interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) + QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row +} + +// PreparerContext is an interface used by PreparexContext. +type PreparerContext interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +// ExecerContext is an interface used by MustExecContext and LoadFileContext +type ExecerContext interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// ExtContext is a union interface which can bind, query, and exec, with Context +// used by NamedQueryContext and NamedExecContext. +type ExtContext interface { + binder + QueryerContext + ExecerContext +} + +// SelectContext executes a query using the provided Queryer, and StructScans +// each row into dest, which must be a slice. If the slice elements are +// scannable, then the result set must have only one column. Otherwise, +// StructScan is used. The *sql.Rows are closed automatically. +// Any placeholder parameters are replaced with supplied args. +func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + rows, err := q.QueryxContext(ctx, query, args...) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// PreparexContext prepares a statement. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { + s, err := p.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err +} + +// GetContext does a QueryRow using the provided Queryer, and scans the +// resulting row to dest. If dest is scannable, the result must only have one +// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like +// row.Scan would. Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + r := q.QueryRowxContext(ctx, query, args...) + return r.scanAny(dest, false) +} + +// LoadFileContext exec's every statement in a file (as a single call to Exec). +// LoadFileContext may return a nil *sql.Result if errors are encountered +// locating or reading the file at path. LoadFile reads the entire file into +// memory, so it is not suitable for loading large data dumps, but can be useful +// for initializing schemas or loading indexes. +// +// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 +// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting +// this by requiring something with DriverName() and then attempting to split the +// queries will be difficult to get right, and its current driver-specific behavior +// is deemed at least not complex in its incorrectness. +func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { + realpath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + contents, err := ioutil.ReadFile(realpath) + if err != nil { + return nil, err + } + res, err := e.ExecContext(ctx, string(contents)) + return &res, err +} + +// MustExecContext execs the query using e and panics if there was an error. +// Any placeholder parameters are replaced with supplied args. +func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { + res, err := e.ExecContext(ctx, query, args...) + if err != nil { + panic(err) + } + return res +} + +// PrepareNamedContext returns an sqlx.NamedStmt +func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { + return prepareNamedContext(ctx, db, query) +} + +// NamedQueryContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { + return NamedQueryContext(ctx, db, query, arg) +} + +// NamedExecContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return NamedExecContext(ctx, db, query, arg) +} + +// SelectContext using this DB. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return SelectContext(ctx, db, dest, query, args...) +} + +// GetContext using this DB. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return GetContext(ctx, db, dest, query, args...) +} + +// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { + return PreparexContext(ctx, db, query) +} + +// QueryxContext queries the database and returns an *sqlx.Rows. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := db.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// QueryRowxContext queries the database and returns an *sqlx.Row. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.DB.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} +} + +// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead +// of an *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// MustBeginContext is canceled. +func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { + tx, err := db.BeginTxx(ctx, opts) + if err != nil { + panic(err) + } + return tx +} + +// MustExecContext (panic) runs MustExec using this database. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, db, query, args...) +} + +// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an +// *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// BeginxContext is canceled. +func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// StmtxContext returns a version of the prepared statement which runs within a +// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. +func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { + var s *sql.Stmt + switch v := stmt.(type) { + case Stmt: + s = v.Stmt + case *Stmt: + s = v.Stmt + case sql.Stmt: + s = &v + case *sql.Stmt: + s = v + default: + panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) + } + return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} +} + +// NamedStmtContext returns a version of the prepared statement which runs +// within a transaction. +func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { + return &NamedStmt{ + QueryString: stmt.QueryString, + Params: stmt.Params, + Stmt: tx.StmtxContext(ctx, stmt.Stmt), + } +} + +// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func (tx *Tx) PreparexContext(ctx context.Context, query string) (*Stmt, error) { + return PreparexContext(ctx, tx, query) +} + +// PrepareNamedContext returns an sqlx.NamedStmt +func (tx *Tx) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { + return prepareNamedContext(ctx, tx, query) +} + +// MustExecContext runs MustExecContext within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, tx, query, args...) +} + +// QueryxContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := tx.Tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err +} + +// SelectContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return SelectContext(ctx, tx, dest, query, args...) +} + +// GetContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return GetContext(ctx, tx, dest, query, args...) +} + +// QueryRowxContext within a transaction and context. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.Tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} +} + +// NamedExecContext using this Tx. +// Any named placeholder parameters are replaced with fields from arg. +func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return NamedExecContext(ctx, tx, query, arg) +} + +// SelectContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return SelectContext(ctx, &qStmt{s}, dest, "", args...) +} + +// GetContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return GetContext(ctx, &qStmt{s}, dest, "", args...) +} + +// MustExecContext (panic) using this statement. Note that the query portion of +// the error output will be blank, as Stmt does not expose its query. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { + return MustExecContext(ctx, &qStmt{s}, "", args...) +} + +// QueryRowxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { + qs := &qStmt{s} + return qs.QueryRowxContext(ctx, "", args...) +} + +// QueryxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { + qs := &qStmt{s} + return qs.QueryxContext(ctx, "", args...) +} + +func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return q.Stmt.QueryContext(ctx, args...) +} + +func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := q.Stmt.QueryContext(ctx, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err +} + +func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := q.Stmt.QueryContext(ctx, args...) + return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} +} + +func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return q.Stmt.ExecContext(ctx, args...) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/LICENSE similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/LICENSE rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/LICENSE diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/doc.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/doc.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/doc.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/doc.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/error.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/error.go similarity index 86% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/error.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/error.go index 8c51c45..b7df735 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/error.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/error.go @@ -124,6 +124,33 @@ func (e *Err) Error() string { return fmt.Sprintf("%s: %v", e.message, err) } +// Format implements fmt.Formatter +// When printing errors with %+v it also prints the stack trace. +// %#v unsurprisingly will print the real underlying type. +func (e *Err) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + fmt.Fprintf(s, "%s", ErrorStack(e)) + return + case s.Flag('#'): + // avoid infinite recursion by wrapping e into a type + // that doesn't implement Formatter. + fmt.Fprintf(s, "%#v", (*unformatter)(e)) + return + } + fallthrough + case 's': + fmt.Fprintf(s, "%s", e.Error()) + } +} + +// helper for Format +type unformatter Err + +func (unformatter) Format() { /* break the fmt.Formatter interface */ } + // SetLocation records the source location of the error at callDepth stack // frames above the call. func (e *Err) SetLocation(callDepth int) { diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/errortypes.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/errortypes.go similarity index 92% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/errortypes.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/errortypes.go index 10b3b19..9b731c4 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/errortypes.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/errortypes.go @@ -282,3 +282,28 @@ func IsMethodNotAllowed(err error) bool { _, ok := err.(*methodNotAllowed) return ok } + +// forbidden represents an error when a request cannot be completed because of +// missing privileges +type forbidden struct { + Err +} + +// Forbiddenf returns an error which satistifes IsForbidden() +func Forbiddenf(format string, args ...interface{}) error { + return &forbidden{wrap(nil, format, "", args...)} +} + +// NewForbidden returns an error which wraps err that satisfies +// IsForbidden(). +func NewForbidden(err error, msg string) error { + return &forbidden{wrap(err, msg, "")} +} + +// IsForbidden reports whether err was created with Forbiddenf() or +// NewForbidden(). +func IsForbidden(err error) bool { + err = Cause(err) + _, ok := err.(*forbidden) + return ok +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/functions.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/functions.go similarity index 99% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/functions.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/functions.go index 994208d..f86b09b 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/functions.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/functions.go @@ -8,7 +8,7 @@ import ( "strings" ) -// New is a drop in replacement for the standard libary errors module that records +// New is a drop in replacement for the standard library errors module that records // the location that the error is created. // // For example: diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/path.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/path.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/path.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/juju/errors/path.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/benchmark.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/benchmark.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/benchmark.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/benchmark.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/check.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/check.go similarity index 96% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/check.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/check.go index c99392a..fc535bc 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/check.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/check.go @@ -86,6 +86,7 @@ type C struct { logb *logger logw io.Writer done chan *C + parallel chan *C reason string mustFail bool tempDir *tempDir @@ -533,6 +534,7 @@ type RunConf struct { BenchmarkTime time.Duration // Defaults to 1 second BenchmarkMem bool KeepWorkDir bool + Exclude string } // Create a new suiteRunner able to run all methods in the given suite. @@ -577,6 +579,17 @@ func newSuiteRunner(suite interface{}, runConf *RunConf) *suiteRunner { } } + var excludeRegexp *regexp.Regexp + if conf.Exclude != "" { + if regexp, err := regexp.Compile(conf.Exclude); err != nil { + msg := "Bad exclude expression: " + err.Error() + runner.tracker.result.RunError = errors.New(msg) + return runner + } else { + excludeRegexp = regexp + } + } + for i := 0; i != suiteNumMethods; i++ { method := newMethod(suiteValue, i) switch method.Info.Name { @@ -597,7 +610,9 @@ func newSuiteRunner(suite interface{}, runConf *RunConf) *suiteRunner { continue } if filterRegexp == nil || method.matches(filterRegexp) { - runner.tests = append(runner.tests, method) + if excludeRegexp == nil || !method.matches(excludeRegexp) { + runner.tests = append(runner.tests, method) + } } } } @@ -611,13 +626,23 @@ func (runner *suiteRunner) run() *Result { if runner.checkFixtureArgs() { c := runner.runFixture(runner.setUpSuite, "", nil) if c == nil || c.status() == succeededSt { + var delayedC []*C for i := 0; i != len(runner.tests); i++ { - c := runner.runTest(runner.tests[i]) + c := runner.forkTest(runner.tests[i]) + select { + case <-c.done: + case <-c.parallel: + delayedC = append(delayedC, c) + } if c.status() == fixturePanickedSt { runner.skipTests(missedSt, runner.tests[i+1:]) break } } + // Wait those parallel tests finish. + for _, delayed := range delayedC { + <-delayed.done + } } else if c != nil && c.status() == skippedSt { runner.skipTests(skippedSt, runner.tests) } else { @@ -655,6 +680,7 @@ func (runner *suiteRunner) forkCall(method *methodType, kind funcKind, testName logw: logw, tempDir: runner.tempDir, done: make(chan *C, 1), + parallel: make(chan *C, 1), timer: timer{benchTime: runner.benchTime}, startTime: time.Now(), benchMem: runner.benchMem, diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/checkers.go similarity index 99% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/checkers.go index bac3387..3749545 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/checkers.go @@ -212,7 +212,7 @@ type hasLenChecker struct { // The HasLen checker verifies that the obtained value has the // provided length. In many cases this is superior to using Equals -// in conjuction with the len function because in case the check +// in conjunction with the len function because in case the check // fails the value itself will be printed, instead of its length, // providing more details for figuring the problem. // diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers2.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/checkers2.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers2.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/checkers2.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/compare.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/compare.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/compare.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/compare.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/helpers.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/helpers.go similarity index 98% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/helpers.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/helpers.go index 58a733b..68e861d 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/helpers.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/helpers.go @@ -76,6 +76,11 @@ func (c *C) Skip(reason string) { c.stopNow() } +// Parallel will mark the test run parallel within a test suite. +func (c *C) Parallel() { + c.parallel <- c +} + // ----------------------------------------------------------------------- // Basic logging. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/printer.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/printer.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/printer.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/printer.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/run.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/run.go similarity index 95% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/run.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/run.go index da8fd79..afa631f 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/run.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/pingcap/check/run.go @@ -42,8 +42,11 @@ var ( newBenchMem = flag.Bool("check.bmem", false, "Report memory benchmarks") newListFlag = flag.Bool("check.list", false, "List the names of all tests that will be run") newWorkFlag = flag.Bool("check.work", false, "Display and do not remove the test working directory") + newExcludeFlag = flag.String("check.exclude", "", "Regular expression to exclude tests to run") ) +var CustomVerboseFlag bool + // TestingT runs all test suites registered with the Suite function, // printing results to stdout, and reporting any failures back to // the "testing" package. @@ -54,12 +57,13 @@ func TestingT(testingT *testing.T) { } conf := &RunConf{ Filter: *oldFilterFlag + *newFilterFlag, - Verbose: *oldVerboseFlag || *newVerboseFlag, + Verbose: *oldVerboseFlag || *newVerboseFlag || CustomVerboseFlag, Stream: *oldStreamFlag || *newStreamFlag, Benchmark: *oldBenchFlag || *newBenchFlag, BenchmarkTime: benchTime, BenchmarkMem: *newBenchMem, KeepWorkDir: *oldWorkFlag || *newWorkFlag, + Exclude: *newExcludeFlag, } if *oldListFlag || *newListFlag { w := bufio.NewWriter(os.Stdout) diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/LICENSE similarity index 94% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/LICENSE rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/LICENSE index 488357b..926d549 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/LICENSE +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/LICENSE @@ -1,4 +1,4 @@ -Copyright (C) 2013-2016 by Maxim Bublis +Copyright (C) 2013-2018 by Maxim Bublis Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/codec.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/codec.go new file mode 100644 index 0000000..656892c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/codec.go @@ -0,0 +1,206 @@ +// Copyright (C) 2013-2018 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package uuid + +import ( + "bytes" + "encoding/hex" + "fmt" +) + +// FromBytes returns UUID converted from raw byte slice input. +// It will return error if the slice isn't 16 bytes long. +func FromBytes(input []byte) (u UUID, err error) { + err = u.UnmarshalBinary(input) + return +} + +// FromBytesOrNil returns UUID converted from raw byte slice input. +// Same behavior as FromBytes, but returns a Nil UUID on error. +func FromBytesOrNil(input []byte) UUID { + uuid, err := FromBytes(input) + if err != nil { + return Nil + } + return uuid +} + +// FromString returns UUID parsed from string input. +// Input is expected in a form accepted by UnmarshalText. +func FromString(input string) (u UUID, err error) { + err = u.UnmarshalText([]byte(input)) + return +} + +// FromStringOrNil returns UUID parsed from string input. +// Same behavior as FromString, but returns a Nil UUID on error. +func FromStringOrNil(input string) UUID { + uuid, err := FromString(input) + if err != nil { + return Nil + } + return uuid +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The encoding is the same as returned by String. +func (u UUID) MarshalText() (text []byte, err error) { + text = []byte(u.String()) + return +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// Following formats are supported: +// "6ba7b810-9dad-11d1-80b4-00c04fd430c8", +// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}", +// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" +// "6ba7b8109dad11d180b400c04fd430c8" +// ABNF for supported UUID text representation follows: +// uuid := canonical | hashlike | braced | urn +// plain := canonical | hashlike +// canonical := 4hexoct '-' 2hexoct '-' 2hexoct '-' 6hexoct +// hashlike := 12hexoct +// braced := '{' plain '}' +// urn := URN ':' UUID-NID ':' plain +// URN := 'urn' +// UUID-NID := 'uuid' +// 12hexoct := 6hexoct 6hexoct +// 6hexoct := 4hexoct 2hexoct +// 4hexoct := 2hexoct 2hexoct +// 2hexoct := hexoct hexoct +// hexoct := hexdig hexdig +// hexdig := '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9' | +// 'a' | 'b' | 'c' | 'd' | 'e' | 'f' | +// 'A' | 'B' | 'C' | 'D' | 'E' | 'F' +func (u *UUID) UnmarshalText(text []byte) (err error) { + switch len(text) { + case 32: + return u.decodeHashLike(text) + case 36: + return u.decodeCanonical(text) + case 38: + return u.decodeBraced(text) + case 41: + fallthrough + case 45: + return u.decodeURN(text) + default: + return fmt.Errorf("uuid: incorrect UUID length: %s", text) + } +} + +// decodeCanonical decodes UUID string in format +// "6ba7b810-9dad-11d1-80b4-00c04fd430c8". +func (u *UUID) decodeCanonical(t []byte) (err error) { + if t[8] != '-' || t[13] != '-' || t[18] != '-' || t[23] != '-' { + return fmt.Errorf("uuid: incorrect UUID format %s", t) + } + + src := t[:] + dst := u[:] + + for i, byteGroup := range byteGroups { + if i > 0 { + src = src[1:] // skip dash + } + _, err = hex.Decode(dst[:byteGroup/2], src[:byteGroup]) + if err != nil { + return + } + src = src[byteGroup:] + dst = dst[byteGroup/2:] + } + + return +} + +// decodeHashLike decodes UUID string in format +// "6ba7b8109dad11d180b400c04fd430c8". +func (u *UUID) decodeHashLike(t []byte) (err error) { + src := t[:] + dst := u[:] + + if _, err = hex.Decode(dst, src); err != nil { + return err + } + return +} + +// decodeBraced decodes UUID string in format +// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}" or in format +// "{6ba7b8109dad11d180b400c04fd430c8}". +func (u *UUID) decodeBraced(t []byte) (err error) { + l := len(t) + + if t[0] != '{' || t[l-1] != '}' { + return fmt.Errorf("uuid: incorrect UUID format %s", t) + } + + return u.decodePlain(t[1 : l-1]) +} + +// decodeURN decodes UUID string in format +// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" or in format +// "urn:uuid:6ba7b8109dad11d180b400c04fd430c8". +func (u *UUID) decodeURN(t []byte) (err error) { + total := len(t) + + urn_uuid_prefix := t[:9] + + if !bytes.Equal(urn_uuid_prefix, urnPrefix) { + return fmt.Errorf("uuid: incorrect UUID format: %s", t) + } + + return u.decodePlain(t[9:total]) +} + +// decodePlain decodes UUID string in canonical format +// "6ba7b810-9dad-11d1-80b4-00c04fd430c8" or in hash-like format +// "6ba7b8109dad11d180b400c04fd430c8". +func (u *UUID) decodePlain(t []byte) (err error) { + switch len(t) { + case 32: + return u.decodeHashLike(t) + case 36: + return u.decodeCanonical(t) + default: + return fmt.Errorf("uuid: incorrrect UUID length: %s", t) + } +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (u UUID) MarshalBinary() (data []byte, err error) { + data = u.Bytes() + return +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It will return error if the slice isn't 16 bytes long. +func (u *UUID) UnmarshalBinary(data []byte) (err error) { + if len(data) != Size { + err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data)) + return + } + copy(u[:], data) + + return +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/generator.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/generator.go new file mode 100644 index 0000000..3f2f1da --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/generator.go @@ -0,0 +1,239 @@ +// Copyright (C) 2013-2018 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package uuid + +import ( + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "encoding/binary" + "hash" + "net" + "os" + "sync" + "time" +) + +// Difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970). +const epochStart = 122192928000000000 + +var ( + global = newDefaultGenerator() + + epochFunc = unixTimeFunc + posixUID = uint32(os.Getuid()) + posixGID = uint32(os.Getgid()) +) + +// NewV1 returns UUID based on current timestamp and MAC address. +func NewV1() UUID { + return global.NewV1() +} + +// NewV2 returns DCE Security UUID based on POSIX UID/GID. +func NewV2(domain byte) UUID { + return global.NewV2(domain) +} + +// NewV3 returns UUID based on MD5 hash of namespace UUID and name. +func NewV3(ns UUID, name string) UUID { + return global.NewV3(ns, name) +} + +// NewV4 returns random generated UUID. +func NewV4() UUID { + return global.NewV4() +} + +// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name. +func NewV5(ns UUID, name string) UUID { + return global.NewV5(ns, name) +} + +// Generator provides interface for generating UUIDs. +type Generator interface { + NewV1() UUID + NewV2(domain byte) UUID + NewV3(ns UUID, name string) UUID + NewV4() UUID + NewV5(ns UUID, name string) UUID +} + +// Default generator implementation. +type generator struct { + storageOnce sync.Once + storageMutex sync.Mutex + + lastTime uint64 + clockSequence uint16 + hardwareAddr [6]byte +} + +func newDefaultGenerator() Generator { + return &generator{} +} + +// NewV1 returns UUID based on current timestamp and MAC address. +func (g *generator) NewV1() UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := g.getStorage() + + binary.BigEndian.PutUint32(u[0:], uint32(timeNow)) + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + + copy(u[10:], hardwareAddr) + + u.SetVersion(V1) + u.SetVariant(VariantRFC4122) + + return u +} + +// NewV2 returns DCE Security UUID based on POSIX UID/GID. +func (g *generator) NewV2(domain byte) UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := g.getStorage() + + switch domain { + case DomainPerson: + binary.BigEndian.PutUint32(u[0:], posixUID) + case DomainGroup: + binary.BigEndian.PutUint32(u[0:], posixGID) + } + + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + u[9] = domain + + copy(u[10:], hardwareAddr) + + u.SetVersion(V2) + u.SetVariant(VariantRFC4122) + + return u +} + +// NewV3 returns UUID based on MD5 hash of namespace UUID and name. +func (g *generator) NewV3(ns UUID, name string) UUID { + u := newFromHash(md5.New(), ns, name) + u.SetVersion(V3) + u.SetVariant(VariantRFC4122) + + return u +} + +// NewV4 returns random generated UUID. +func (g *generator) NewV4() UUID { + u := UUID{} + g.safeRandom(u[:]) + u.SetVersion(V4) + u.SetVariant(VariantRFC4122) + + return u +} + +// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name. +func (g *generator) NewV5(ns UUID, name string) UUID { + u := newFromHash(sha1.New(), ns, name) + u.SetVersion(V5) + u.SetVariant(VariantRFC4122) + + return u +} + +func (g *generator) initStorage() { + g.initClockSequence() + g.initHardwareAddr() +} + +func (g *generator) initClockSequence() { + buf := make([]byte, 2) + g.safeRandom(buf) + g.clockSequence = binary.BigEndian.Uint16(buf) +} + +func (g *generator) initHardwareAddr() { + interfaces, err := net.Interfaces() + if err == nil { + for _, iface := range interfaces { + if len(iface.HardwareAddr) >= 6 { + copy(g.hardwareAddr[:], iface.HardwareAddr) + return + } + } + } + + // Initialize hardwareAddr randomly in case + // of real network interfaces absence + g.safeRandom(g.hardwareAddr[:]) + + // Set multicast bit as recommended in RFC 4122 + g.hardwareAddr[0] |= 0x01 +} + +func (g *generator) safeRandom(dest []byte) { + if _, err := rand.Read(dest); err != nil { + panic(err) + } +} + +// Returns UUID v1/v2 storage state. +// Returns epoch timestamp, clock sequence, and hardware address. +func (g *generator) getStorage() (uint64, uint16, []byte) { + g.storageOnce.Do(g.initStorage) + + g.storageMutex.Lock() + defer g.storageMutex.Unlock() + + timeNow := epochFunc() + // Clock changed backwards since last UUID generation. + // Should increase clock sequence. + if timeNow <= g.lastTime { + g.clockSequence++ + } + g.lastTime = timeNow + + return timeNow, g.clockSequence, g.hardwareAddr[:] +} + +// Returns difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and current time. +// This is default epoch calculation function. +func unixTimeFunc() uint64 { + return epochStart + uint64(time.Now().UnixNano()/100) +} + +// Returns UUID based on hashing of namespace UUID and name. +func newFromHash(h hash.Hash, ns UUID, name string) UUID { + u := UUID{} + h.Write(ns[:]) + h.Write([]byte(name)) + copy(u[:], h.Sum(nil)) + + return u +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/sql.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/sql.go new file mode 100644 index 0000000..56759d3 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/sql.go @@ -0,0 +1,78 @@ +// Copyright (C) 2013-2018 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package uuid + +import ( + "database/sql/driver" + "fmt" +) + +// Value implements the driver.Valuer interface. +func (u UUID) Value() (driver.Value, error) { + return u.String(), nil +} + +// Scan implements the sql.Scanner interface. +// A 16-byte slice is handled by UnmarshalBinary, while +// a longer byte slice or a string is handled by UnmarshalText. +func (u *UUID) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + if len(src) == Size { + return u.UnmarshalBinary(src) + } + return u.UnmarshalText(src) + + case string: + return u.UnmarshalText([]byte(src)) + } + + return fmt.Errorf("uuid: cannot convert %T to UUID", src) +} + +// NullUUID can be used with the standard sql package to represent a +// UUID value that can be NULL in the database +type NullUUID struct { + UUID UUID + Valid bool +} + +// Value implements the driver.Valuer interface. +func (u NullUUID) Value() (driver.Value, error) { + if !u.Valid { + return nil, nil + } + // Delegate to UUID Value function + return u.UUID.Value() +} + +// Scan implements the sql.Scanner interface. +func (u *NullUUID) Scan(src interface{}) error { + if src == nil { + u.UUID, u.Valid = Nil, false + return nil + } + + // Delegate to UUID Scan function + u.Valid = true + return u.UUID.Scan(src) +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/uuid.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/uuid.go new file mode 100644 index 0000000..a2b8e2c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/satori/go.uuid/uuid.go @@ -0,0 +1,161 @@ +// Copyright (C) 2013-2018 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Package uuid provides implementation of Universally Unique Identifier (UUID). +// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and +// version 2 (as specified in DCE 1.1). +package uuid + +import ( + "bytes" + "encoding/hex" +) + +// Size of a UUID in bytes. +const Size = 16 + +// UUID representation compliant with specification +// described in RFC 4122. +type UUID [Size]byte + +// UUID versions +const ( + _ byte = iota + V1 + V2 + V3 + V4 + V5 +) + +// UUID layout variants. +const ( + VariantNCS byte = iota + VariantRFC4122 + VariantMicrosoft + VariantFuture +) + +// UUID DCE domains. +const ( + DomainPerson = iota + DomainGroup + DomainOrg +) + +// String parse helpers. +var ( + urnPrefix = []byte("urn:uuid:") + byteGroups = []int{8, 4, 4, 4, 12} +) + +// Nil is special form of UUID that is specified to have all +// 128 bits set to zero. +var Nil = UUID{} + +// Predefined namespace UUIDs. +var ( + NamespaceDNS = Must(FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8")) + NamespaceURL = Must(FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8")) + NamespaceOID = Must(FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8")) + NamespaceX500 = Must(FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8")) +) + +// Equal returns true if u1 and u2 equals, otherwise returns false. +func Equal(u1 UUID, u2 UUID) bool { + return bytes.Equal(u1[:], u2[:]) +} + +// Version returns algorithm version used to generate UUID. +func (u UUID) Version() byte { + return u[6] >> 4 +} + +// Variant returns UUID layout variant. +func (u UUID) Variant() byte { + switch { + case (u[8] >> 7) == 0x00: + return VariantNCS + case (u[8] >> 6) == 0x02: + return VariantRFC4122 + case (u[8] >> 5) == 0x06: + return VariantMicrosoft + case (u[8] >> 5) == 0x07: + fallthrough + default: + return VariantFuture + } +} + +// Bytes returns bytes slice representation of UUID. +func (u UUID) Bytes() []byte { + return u[:] +} + +// Returns canonical string representation of UUID: +// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. +func (u UUID) String() string { + buf := make([]byte, 36) + + hex.Encode(buf[0:8], u[0:4]) + buf[8] = '-' + hex.Encode(buf[9:13], u[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], u[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], u[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], u[10:]) + + return string(buf) +} + +// SetVersion sets version bits. +func (u *UUID) SetVersion(v byte) { + u[6] = (u[6] & 0x0f) | (v << 4) +} + +// SetVariant sets variant bits. +func (u *UUID) SetVariant(v byte) { + switch v { + case VariantNCS: + u[8] = (u[8]&(0xff>>1) | (0x00 << 7)) + case VariantRFC4122: + u[8] = (u[8]&(0xff>>2) | (0x02 << 6)) + case VariantMicrosoft: + u[8] = (u[8]&(0xff>>3) | (0x06 << 5)) + case VariantFuture: + fallthrough + default: + u[8] = (u[8]&(0xff>>3) | (0x07 << 5)) + } +} + +// Must is a helper that wraps a call to a function returning (UUID, error) +// and panics if the error is non-nil. It is intended for use in variable +// initializations such as +// var packageUUID = uuid.Must(uuid.FromString("123e4567-e89b-12d3-a456-426655440000")); +func Must(u UUID, err error) UUID { + if err != nil { + panic(err) + } + return u +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/LICENSE new file mode 100644 index 0000000..ad2148a --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/LICENSE @@ -0,0 +1,45 @@ +The MIT License (MIT) + +Copyright (c) 2015 Spring, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +- Based on https://github.com/oguzbilgic/fpd, which has the following license: +""" +The MIT License (MIT) + +Copyright (c) 2013 Oguz Bilgic + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/decimal-go.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/decimal-go.go new file mode 100644 index 0000000..e08a15c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/decimal-go.go @@ -0,0 +1,414 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Multiprecision decimal numbers. +// For floating-point formatting only; not general purpose. +// Only operations are assign and (binary) left/right shift. +// Can do binary floating point in multiprecision decimal precisely +// because 2 divides 10; cannot do decimal floating point +// in multiprecision binary precisely. +package decimal + +type decimal struct { + d [800]byte // digits, big-endian representation + nd int // number of digits used + dp int // decimal point + neg bool // negative flag + trunc bool // discarded nonzero digits beyond d[:nd] +} + +func (a *decimal) String() string { + n := 10 + a.nd + if a.dp > 0 { + n += a.dp + } + if a.dp < 0 { + n += -a.dp + } + + buf := make([]byte, n) + w := 0 + switch { + case a.nd == 0: + return "0" + + case a.dp <= 0: + // zeros fill space between decimal point and digits + buf[w] = '0' + w++ + buf[w] = '.' + w++ + w += digitZero(buf[w : w+-a.dp]) + w += copy(buf[w:], a.d[0:a.nd]) + + case a.dp < a.nd: + // decimal point in middle of digits + w += copy(buf[w:], a.d[0:a.dp]) + buf[w] = '.' + w++ + w += copy(buf[w:], a.d[a.dp:a.nd]) + + default: + // zeros fill space between digits and decimal point + w += copy(buf[w:], a.d[0:a.nd]) + w += digitZero(buf[w : w+a.dp-a.nd]) + } + return string(buf[0:w]) +} + +func digitZero(dst []byte) int { + for i := range dst { + dst[i] = '0' + } + return len(dst) +} + +// trim trailing zeros from number. +// (They are meaningless; the decimal point is tracked +// independent of the number of digits.) +func trim(a *decimal) { + for a.nd > 0 && a.d[a.nd-1] == '0' { + a.nd-- + } + if a.nd == 0 { + a.dp = 0 + } +} + +// Assign v to a. +func (a *decimal) Assign(v uint64) { + var buf [24]byte + + // Write reversed decimal in buf. + n := 0 + for v > 0 { + v1 := v / 10 + v -= 10 * v1 + buf[n] = byte(v + '0') + n++ + v = v1 + } + + // Reverse again to produce forward decimal in a.d. + a.nd = 0 + for n--; n >= 0; n-- { + a.d[a.nd] = buf[n] + a.nd++ + } + a.dp = a.nd + trim(a) +} + +// Maximum shift that we can do in one pass without overflow. +// A uint has 32 or 64 bits, and we have to be able to accommodate 9<> 63) +const maxShift = uintSize - 4 + +// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow. +func rightShift(a *decimal, k uint) { + r := 0 // read pointer + w := 0 // write pointer + + // Pick up enough leading digits to cover first shift. + var n uint + for ; n>>k == 0; r++ { + if r >= a.nd { + if n == 0 { + // a == 0; shouldn't get here, but handle anyway. + a.nd = 0 + return + } + for n>>k == 0 { + n = n * 10 + r++ + } + break + } + c := uint(a.d[r]) + n = n*10 + c - '0' + } + a.dp -= r - 1 + + var mask uint = (1 << k) - 1 + + // Pick up a digit, put down a digit. + for ; r < a.nd; r++ { + c := uint(a.d[r]) + dig := n >> k + n &= mask + a.d[w] = byte(dig + '0') + w++ + n = n*10 + c - '0' + } + + // Put down extra digits. + for n > 0 { + dig := n >> k + n &= mask + if w < len(a.d) { + a.d[w] = byte(dig + '0') + w++ + } else if dig > 0 { + a.trunc = true + } + n = n * 10 + } + + a.nd = w + trim(a) +} + +// Cheat sheet for left shift: table indexed by shift count giving +// number of new digits that will be introduced by that shift. +// +// For example, leftcheats[4] = {2, "625"}. That means that +// if we are shifting by 4 (multiplying by 16), it will add 2 digits +// when the string prefix is "625" through "999", and one fewer digit +// if the string prefix is "000" through "624". +// +// Credit for this trick goes to Ken. + +type leftCheat struct { + delta int // number of new digits + cutoff string // minus one digit if original < a. +} + +var leftcheats = []leftCheat{ + // Leading digits of 1/2^i = 5^i. + // 5^23 is not an exact 64-bit floating point number, + // so have to use bc for the math. + // Go up to 60 to be large enough for 32bit and 64bit platforms. + /* + seq 60 | sed 's/^/5^/' | bc | + awk 'BEGIN{ print "\t{ 0, \"\" }," } + { + log2 = log(2)/log(10) + printf("\t{ %d, \"%s\" },\t// * %d\n", + int(log2*NR+1), $0, 2**NR) + }' + */ + {0, ""}, + {1, "5"}, // * 2 + {1, "25"}, // * 4 + {1, "125"}, // * 8 + {2, "625"}, // * 16 + {2, "3125"}, // * 32 + {2, "15625"}, // * 64 + {3, "78125"}, // * 128 + {3, "390625"}, // * 256 + {3, "1953125"}, // * 512 + {4, "9765625"}, // * 1024 + {4, "48828125"}, // * 2048 + {4, "244140625"}, // * 4096 + {4, "1220703125"}, // * 8192 + {5, "6103515625"}, // * 16384 + {5, "30517578125"}, // * 32768 + {5, "152587890625"}, // * 65536 + {6, "762939453125"}, // * 131072 + {6, "3814697265625"}, // * 262144 + {6, "19073486328125"}, // * 524288 + {7, "95367431640625"}, // * 1048576 + {7, "476837158203125"}, // * 2097152 + {7, "2384185791015625"}, // * 4194304 + {7, "11920928955078125"}, // * 8388608 + {8, "59604644775390625"}, // * 16777216 + {8, "298023223876953125"}, // * 33554432 + {8, "1490116119384765625"}, // * 67108864 + {9, "7450580596923828125"}, // * 134217728 + {9, "37252902984619140625"}, // * 268435456 + {9, "186264514923095703125"}, // * 536870912 + {10, "931322574615478515625"}, // * 1073741824 + {10, "4656612873077392578125"}, // * 2147483648 + {10, "23283064365386962890625"}, // * 4294967296 + {10, "116415321826934814453125"}, // * 8589934592 + {11, "582076609134674072265625"}, // * 17179869184 + {11, "2910383045673370361328125"}, // * 34359738368 + {11, "14551915228366851806640625"}, // * 68719476736 + {12, "72759576141834259033203125"}, // * 137438953472 + {12, "363797880709171295166015625"}, // * 274877906944 + {12, "1818989403545856475830078125"}, // * 549755813888 + {13, "9094947017729282379150390625"}, // * 1099511627776 + {13, "45474735088646411895751953125"}, // * 2199023255552 + {13, "227373675443232059478759765625"}, // * 4398046511104 + {13, "1136868377216160297393798828125"}, // * 8796093022208 + {14, "5684341886080801486968994140625"}, // * 17592186044416 + {14, "28421709430404007434844970703125"}, // * 35184372088832 + {14, "142108547152020037174224853515625"}, // * 70368744177664 + {15, "710542735760100185871124267578125"}, // * 140737488355328 + {15, "3552713678800500929355621337890625"}, // * 281474976710656 + {15, "17763568394002504646778106689453125"}, // * 562949953421312 + {16, "88817841970012523233890533447265625"}, // * 1125899906842624 + {16, "444089209850062616169452667236328125"}, // * 2251799813685248 + {16, "2220446049250313080847263336181640625"}, // * 4503599627370496 + {16, "11102230246251565404236316680908203125"}, // * 9007199254740992 + {17, "55511151231257827021181583404541015625"}, // * 18014398509481984 + {17, "277555756156289135105907917022705078125"}, // * 36028797018963968 + {17, "1387778780781445675529539585113525390625"}, // * 72057594037927936 + {18, "6938893903907228377647697925567626953125"}, // * 144115188075855872 + {18, "34694469519536141888238489627838134765625"}, // * 288230376151711744 + {18, "173472347597680709441192448139190673828125"}, // * 576460752303423488 + {19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976 +} + +// Is the leading prefix of b lexicographically less than s? +func prefixIsLessThan(b []byte, s string) bool { + for i := 0; i < len(s); i++ { + if i >= len(b) { + return true + } + if b[i] != s[i] { + return b[i] < s[i] + } + } + return false +} + +// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow. +func leftShift(a *decimal, k uint) { + delta := leftcheats[k].delta + if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) { + delta-- + } + + r := a.nd // read index + w := a.nd + delta // write index + + // Pick up a digit, put down a digit. + var n uint + for r--; r >= 0; r-- { + n += (uint(a.d[r]) - '0') << k + quo := n / 10 + rem := n - 10*quo + w-- + if w < len(a.d) { + a.d[w] = byte(rem + '0') + } else if rem != 0 { + a.trunc = true + } + n = quo + } + + // Put down extra digits. + for n > 0 { + quo := n / 10 + rem := n - 10*quo + w-- + if w < len(a.d) { + a.d[w] = byte(rem + '0') + } else if rem != 0 { + a.trunc = true + } + n = quo + } + + a.nd += delta + if a.nd >= len(a.d) { + a.nd = len(a.d) + } + a.dp += delta + trim(a) +} + +// Binary shift left (k > 0) or right (k < 0). +func (a *decimal) Shift(k int) { + switch { + case a.nd == 0: + // nothing to do: a == 0 + case k > 0: + for k > maxShift { + leftShift(a, maxShift) + k -= maxShift + } + leftShift(a, uint(k)) + case k < 0: + for k < -maxShift { + rightShift(a, maxShift) + k += maxShift + } + rightShift(a, uint(-k)) + } +} + +// If we chop a at nd digits, should we round up? +func shouldRoundUp(a *decimal, nd int) bool { + if nd < 0 || nd >= a.nd { + return false + } + if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even + // if we truncated, a little higher than what's recorded - always round up + if a.trunc { + return true + } + return nd > 0 && (a.d[nd-1]-'0')%2 != 0 + } + // not halfway - digit tells all + return a.d[nd] >= '5' +} + +// Round a to nd digits (or fewer). +// If nd is zero, it means we're rounding +// just to the left of the digits, as in +// 0.09 -> 0.1. +func (a *decimal) Round(nd int) { + if nd < 0 || nd >= a.nd { + return + } + if shouldRoundUp(a, nd) { + a.RoundUp(nd) + } else { + a.RoundDown(nd) + } +} + +// Round a down to nd digits (or fewer). +func (a *decimal) RoundDown(nd int) { + if nd < 0 || nd >= a.nd { + return + } + a.nd = nd + trim(a) +} + +// Round a up to nd digits (or fewer). +func (a *decimal) RoundUp(nd int) { + if nd < 0 || nd >= a.nd { + return + } + + // round up + for i := nd - 1; i >= 0; i-- { + c := a.d[i] + if c < '9' { // can stop after this digit + a.d[i]++ + a.nd = i + 1 + return + } + } + + // Number is all 9s. + // Change to single 1 with adjusted decimal point. + a.d[0] = '1' + a.nd = 1 + a.dp++ +} + +// Extract integer part, rounded appropriately. +// No guarantees about overflow. +func (a *decimal) RoundedInteger() uint64 { + if a.dp > 20 { + return 0xFFFFFFFFFFFFFFFF + } + var i int + n := uint64(0) + for i = 0; i < a.dp && i < a.nd; i++ { + n = n*10 + uint64(a.d[i]-'0') + } + for ; i < a.dp; i++ { + n *= 10 + } + if shouldRoundUp(a, a.dp) { + n++ + } + return n +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/decimal.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/decimal.go new file mode 100644 index 0000000..134ece2 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/decimal.go @@ -0,0 +1,1434 @@ +// Package decimal implements an arbitrary precision fixed-point decimal. +// +// To use as part of a struct: +// +// type Struct struct { +// Number Decimal +// } +// +// The zero-value of a Decimal is 0, as you would expect. +// +// The best way to create a new Decimal is to use decimal.NewFromString, ex: +// +// n, err := decimal.NewFromString("-123.4567") +// n.String() // output: "-123.4567" +// +// NOTE: This can "only" represent numbers with a maximum of 2^31 digits +// after the decimal point. +package decimal + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "math/big" + "strconv" + "strings" +) + +// DivisionPrecision is the number of decimal places in the result when it +// doesn't divide exactly. +// +// Example: +// +// d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3) +// d1.String() // output: "0.6666666666666667" +// d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000) +// d2.String() // output: "0.0000666666666667" +// d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3) +// d3.String() // output: "6666.6666666666666667" +// decimal.DivisionPrecision = 3 +// d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3) +// d4.String() // output: "0.667" +// +var DivisionPrecision = 16 + +// MarshalJSONWithoutQuotes should be set to true if you want the decimal to +// be JSON marshaled as a number, instead of as a string. +// WARNING: this is dangerous for decimals with many digits, since many JSON +// unmarshallers (ex: Javascript's) will unmarshal JSON numbers to IEEE 754 +// double-precision floating point numbers, which means you can potentially +// silently lose precision. +var MarshalJSONWithoutQuotes = false + +// Zero constant, to make computations faster. +var Zero = New(0, 1) + +// fiveDec used in Cash Rounding +var fiveDec = New(5, 0) + +var zeroInt = big.NewInt(0) +var oneInt = big.NewInt(1) +var twoInt = big.NewInt(2) +var fourInt = big.NewInt(4) +var fiveInt = big.NewInt(5) +var tenInt = big.NewInt(10) +var twentyInt = big.NewInt(20) + +// Decimal represents a fixed-point decimal. It is immutable. +// number = value * 10 ^ exp +type Decimal struct { + value *big.Int + + // NOTE(vadim): this must be an int32, because we cast it to float64 during + // calculations. If exp is 64 bit, we might lose precision. + // If we cared about being able to represent every possible decimal, we + // could make exp a *big.Int but it would hurt performance and numbers + // like that are unrealistic. + exp int32 +} + +// New returns a new fixed-point decimal, value * 10 ^ exp. +func New(value int64, exp int32) Decimal { + return Decimal{ + value: big.NewInt(value), + exp: exp, + } +} + +// NewFromBigInt returns a new Decimal from a big.Int, value * 10 ^ exp +func NewFromBigInt(value *big.Int, exp int32) Decimal { + return Decimal{ + value: big.NewInt(0).Set(value), + exp: exp, + } +} + +// NewFromString returns a new Decimal from a string representation. +// +// Example: +// +// d, err := NewFromString("-123.45") +// d2, err := NewFromString(".0001") +// +func NewFromString(value string) (Decimal, error) { + originalInput := value + var intString string + var exp int64 + + // Check if number is using scientific notation + eIndex := strings.IndexAny(value, "Ee") + if eIndex != -1 { + expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32) + if err != nil { + if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", value) + } + return Decimal{}, fmt.Errorf("can't convert %s to decimal: exponent is not numeric", value) + } + value = value[:eIndex] + exp = expInt + } + + parts := strings.Split(value, ".") + if len(parts) == 1 { + // There is no decimal point, we can just parse the original string as + // an int + intString = value + } else if len(parts) == 2 { + // strip the insignificant digits for more accurate comparisons. + decimalPart := strings.TrimRight(parts[1], "0") + intString = parts[0] + decimalPart + expInt := -len(decimalPart) + exp += int64(expInt) + } else { + return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) + } + + dValue := new(big.Int) + _, ok := dValue.SetString(intString, 10) + if !ok { + return Decimal{}, fmt.Errorf("can't convert %s to decimal", value) + } + + if exp < math.MinInt32 || exp > math.MaxInt32 { + // NOTE(vadim): I doubt a string could realistically be this long + return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", originalInput) + } + + return Decimal{ + value: dValue, + exp: int32(exp), + }, nil +} + +// RequireFromString returns a new Decimal from a string representation +// or panics if NewFromString would have returned an error. +// +// Example: +// +// d := RequireFromString("-123.45") +// d2 := RequireFromString(".0001") +// +func RequireFromString(value string) Decimal { + dec, err := NewFromString(value) + if err != nil { + panic(err) + } + return dec +} + +// NewFromFloat converts a float64 to Decimal. +// +// The converted number will contain the number of significant digits that can be +// represented in a float with reliable roundtrip. +// This is typically 15 digits, but may be more in some cases. +// See https://www.exploringbinary.com/decimal-precision-of-binary-floating-point-numbers/ for more information. +// +// For slightly faster conversion, use NewFromFloatWithExponent where you can specify the precision in absolute terms. +// +// NOTE: this will panic on NaN, +/-inf +func NewFromFloat(value float64) Decimal { + if value == 0 { + return New(0, 0) + } + return newFromFloat(value, math.Float64bits(value), &float64info) +} + +// NewFromFloat converts a float32 to Decimal. +// +// The converted number will contain the number of significant digits that can be +// represented in a float with reliable roundtrip. +// This is typically 6-8 digits depending on the input. +// See https://www.exploringbinary.com/decimal-precision-of-binary-floating-point-numbers/ for more information. +// +// For slightly faster conversion, use NewFromFloatWithExponent where you can specify the precision in absolute terms. +// +// NOTE: this will panic on NaN, +/-inf +func NewFromFloat32(value float32) Decimal { + if value == 0 { + return New(0, 0) + } + // XOR is workaround for https://github.com/golang/go/issues/26285 + a := math.Float32bits(value) ^ 0x80808080 + return newFromFloat(float64(value), uint64(a)^0x80808080, &float32info) +} + +func newFromFloat(val float64, bits uint64, flt *floatInfo) Decimal { + if math.IsNaN(val) || math.IsInf(val, 0) { + panic(fmt.Sprintf("Cannot create a Decimal from %v", val)) + } + exp := int(bits>>flt.mantbits) & (1<>(flt.expbits+flt.mantbits) != 0 + + roundShortest(&d, mant, exp, flt) + // If less than 19 digits, we can do calculation in an int64. + if d.nd < 19 { + tmp := int64(0) + m := int64(1) + for i := d.nd - 1; i >= 0; i-- { + tmp += m * int64(d.d[i]-'0') + m *= 10 + } + if d.neg { + tmp *= -1 + } + return Decimal{value: big.NewInt(tmp), exp: int32(d.dp) - int32(d.nd)} + } + dValue := new(big.Int) + dValue, ok := dValue.SetString(string(d.d[:d.nd]), 10) + if ok { + return Decimal{value: dValue, exp: int32(d.dp) - int32(d.nd)} + } + + return NewFromFloatWithExponent(val, int32(d.dp)-int32(d.nd)) +} + +// NewFromFloatWithExponent converts a float64 to Decimal, with an arbitrary +// number of fractional digits. +// +// Example: +// +// NewFromFloatWithExponent(123.456, -2).String() // output: "123.46" +// +func NewFromFloatWithExponent(value float64, exp int32) Decimal { + if math.IsNaN(value) || math.IsInf(value, 0) { + panic(fmt.Sprintf("Cannot create a Decimal from %v", value)) + } + + bits := math.Float64bits(value) + mant := bits & (1<<52 - 1) + exp2 := int32((bits >> 52) & (1<<11 - 1)) + sign := bits >> 63 + + if exp2 == 0 { + // specials + if mant == 0 { + return Decimal{} + } else { + // subnormal + exp2++ + } + } else { + // normal + mant |= 1 << 52 + } + + exp2 -= 1023 + 52 + + // normalizing base-2 values + for mant&1 == 0 { + mant = mant >> 1 + exp2++ + } + + // maximum number of fractional base-10 digits to represent 2^N exactly cannot be more than -N if N<0 + if exp < 0 && exp < exp2 { + if exp2 < 0 { + exp = exp2 + } else { + exp = 0 + } + } + + // representing 10^M * 2^N as 5^M * 2^(M+N) + exp2 -= exp + + temp := big.NewInt(1) + dMant := big.NewInt(int64(mant)) + + // applying 5^M + if exp > 0 { + temp = temp.SetInt64(int64(exp)) + temp = temp.Exp(fiveInt, temp, nil) + } else if exp < 0 { + temp = temp.SetInt64(-int64(exp)) + temp = temp.Exp(fiveInt, temp, nil) + dMant = dMant.Mul(dMant, temp) + temp = temp.SetUint64(1) + } + + // applying 2^(M+N) + if exp2 > 0 { + dMant = dMant.Lsh(dMant, uint(exp2)) + } else if exp2 < 0 { + temp = temp.Lsh(temp, uint(-exp2)) + } + + // rounding and downscaling + if exp > 0 || exp2 < 0 { + halfDown := new(big.Int).Rsh(temp, 1) + dMant = dMant.Add(dMant, halfDown) + dMant = dMant.Quo(dMant, temp) + } + + if sign == 1 { + dMant = dMant.Neg(dMant) + } + + return Decimal{ + value: dMant, + exp: exp, + } +} + +// rescale returns a rescaled version of the decimal. Returned +// decimal may be less precise if the given exponent is bigger +// than the initial exponent of the Decimal. +// NOTE: this will truncate, NOT round +// +// Example: +// +// d := New(12345, -4) +// d2 := d.rescale(-1) +// d3 := d2.rescale(-4) +// println(d1) +// println(d2) +// println(d3) +// +// Output: +// +// 1.2345 +// 1.2 +// 1.2000 +// +func (d Decimal) rescale(exp int32) Decimal { + d.ensureInitialized() + // NOTE(vadim): must convert exps to float64 before - to prevent overflow + diff := math.Abs(float64(exp) - float64(d.exp)) + value := new(big.Int).Set(d.value) + + expScale := new(big.Int).Exp(tenInt, big.NewInt(int64(diff)), nil) + if exp > d.exp { + value = value.Quo(value, expScale) + } else if exp < d.exp { + value = value.Mul(value, expScale) + } + + return Decimal{ + value: value, + exp: exp, + } +} + +// Abs returns the absolute value of the decimal. +func (d Decimal) Abs() Decimal { + d.ensureInitialized() + d2Value := new(big.Int).Abs(d.value) + return Decimal{ + value: d2Value, + exp: d.exp, + } +} + +// Add returns d + d2. +func (d Decimal) Add(d2 Decimal) Decimal { + baseScale := min(d.exp, d2.exp) + rd := d.rescale(baseScale) + rd2 := d2.rescale(baseScale) + + d3Value := new(big.Int).Add(rd.value, rd2.value) + return Decimal{ + value: d3Value, + exp: baseScale, + } +} + +// Sub returns d - d2. +func (d Decimal) Sub(d2 Decimal) Decimal { + baseScale := min(d.exp, d2.exp) + rd := d.rescale(baseScale) + rd2 := d2.rescale(baseScale) + + d3Value := new(big.Int).Sub(rd.value, rd2.value) + return Decimal{ + value: d3Value, + exp: baseScale, + } +} + +// Neg returns -d. +func (d Decimal) Neg() Decimal { + d.ensureInitialized() + val := new(big.Int).Neg(d.value) + return Decimal{ + value: val, + exp: d.exp, + } +} + +// Mul returns d * d2. +func (d Decimal) Mul(d2 Decimal) Decimal { + d.ensureInitialized() + d2.ensureInitialized() + + expInt64 := int64(d.exp) + int64(d2.exp) + if expInt64 > math.MaxInt32 || expInt64 < math.MinInt32 { + // NOTE(vadim): better to panic than give incorrect results, as + // Decimals are usually used for money + panic(fmt.Sprintf("exponent %v overflows an int32!", expInt64)) + } + + d3Value := new(big.Int).Mul(d.value, d2.value) + return Decimal{ + value: d3Value, + exp: int32(expInt64), + } +} + +// Shift shifts the decimal in base 10. +// It shifts left when shift is positive and right if shift is negative. +// In simpler terms, the given value for shift is added to the exponent +// of the decimal. +func (d Decimal) Shift(shift int32) Decimal { + d.ensureInitialized() + return Decimal{ + value: new(big.Int).Set(d.value), + exp: d.exp + shift, + } +} + +// Div returns d / d2. If it doesn't divide exactly, the result will have +// DivisionPrecision digits after the decimal point. +func (d Decimal) Div(d2 Decimal) Decimal { + return d.DivRound(d2, int32(DivisionPrecision)) +} + +// QuoRem does divsion with remainder +// d.QuoRem(d2,precision) returns quotient q and remainder r such that +// d = d2 * q + r, q an integer multiple of 10^(-precision) +// 0 <= r < abs(d2) * 10 ^(-precision) if d>=0 +// 0 >= r > -abs(d2) * 10 ^(-precision) if d<0 +// Note that precision<0 is allowed as input. +func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) { + d.ensureInitialized() + d2.ensureInitialized() + if d2.value.Sign() == 0 { + panic("decimal division by 0") + } + scale := -precision + e := int64(d.exp - d2.exp - scale) + if e > math.MaxInt32 || e < math.MinInt32 { + panic("overflow in decimal QuoRem") + } + var aa, bb, expo big.Int + var scalerest int32 + // d = a 10^ea + // d2 = b 10^eb + if e < 0 { + aa = *d.value + expo.SetInt64(-e) + bb.Exp(tenInt, &expo, nil) + bb.Mul(d2.value, &bb) + scalerest = d.exp + // now aa = a + // bb = b 10^(scale + eb - ea) + } else { + expo.SetInt64(e) + aa.Exp(tenInt, &expo, nil) + aa.Mul(d.value, &aa) + bb = *d2.value + scalerest = scale + d2.exp + // now aa = a ^ (ea - eb - scale) + // bb = b + } + var q, r big.Int + q.QuoRem(&aa, &bb, &r) + dq := Decimal{value: &q, exp: scale} + dr := Decimal{value: &r, exp: scalerest} + return dq, dr +} + +// DivRound divides and rounds to a given precision +// i.e. to an integer multiple of 10^(-precision) +// for a positive quotient digit 5 is rounded up, away from 0 +// if the quotient is negative then digit 5 is rounded down, away from 0 +// Note that precision<0 is allowed as input. +func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal { + // QuoRem already checks initialization + q, r := d.QuoRem(d2, precision) + // the actual rounding decision is based on comparing r*10^precision and d2/2 + // instead compare 2 r 10 ^precision and d2 + var rv2 big.Int + rv2.Abs(r.value) + rv2.Lsh(&rv2, 1) + // now rv2 = abs(r.value) * 2 + r2 := Decimal{value: &rv2, exp: r.exp + precision} + // r2 is now 2 * r * 10 ^ precision + var c = r2.Cmp(d2.Abs()) + + if c < 0 { + return q + } + + if d.value.Sign()*d2.value.Sign() < 0 { + return q.Sub(New(1, -precision)) + } + + return q.Add(New(1, -precision)) +} + +// Mod returns d % d2. +func (d Decimal) Mod(d2 Decimal) Decimal { + quo := d.Div(d2).Truncate(0) + return d.Sub(d2.Mul(quo)) +} + +// Pow returns d to the power d2 +func (d Decimal) Pow(d2 Decimal) Decimal { + var temp Decimal + if d2.IntPart() == 0 { + return NewFromFloat(1) + } + temp = d.Pow(d2.Div(NewFromFloat(2))) + if d2.IntPart()%2 == 0 { + return temp.Mul(temp) + } + if d2.IntPart() > 0 { + return temp.Mul(temp).Mul(d) + } + return temp.Mul(temp).Div(d) +} + +// Cmp compares the numbers represented by d and d2 and returns: +// +// -1 if d < d2 +// 0 if d == d2 +// +1 if d > d2 +// +func (d Decimal) Cmp(d2 Decimal) int { + d.ensureInitialized() + d2.ensureInitialized() + + if d.exp == d2.exp { + return d.value.Cmp(d2.value) + } + + baseExp := min(d.exp, d2.exp) + rd := d.rescale(baseExp) + rd2 := d2.rescale(baseExp) + + return rd.value.Cmp(rd2.value) +} + +// Equal returns whether the numbers represented by d and d2 are equal. +func (d Decimal) Equal(d2 Decimal) bool { + return d.Cmp(d2) == 0 +} + +// Equals is deprecated, please use Equal method instead +func (d Decimal) Equals(d2 Decimal) bool { + return d.Equal(d2) +} + +// GreaterThan (GT) returns true when d is greater than d2. +func (d Decimal) GreaterThan(d2 Decimal) bool { + return d.Cmp(d2) == 1 +} + +// GreaterThanOrEqual (GTE) returns true when d is greater than or equal to d2. +func (d Decimal) GreaterThanOrEqual(d2 Decimal) bool { + cmp := d.Cmp(d2) + return cmp == 1 || cmp == 0 +} + +// LessThan (LT) returns true when d is less than d2. +func (d Decimal) LessThan(d2 Decimal) bool { + return d.Cmp(d2) == -1 +} + +// LessThanOrEqual (LTE) returns true when d is less than or equal to d2. +func (d Decimal) LessThanOrEqual(d2 Decimal) bool { + cmp := d.Cmp(d2) + return cmp == -1 || cmp == 0 +} + +// Sign returns: +// +// -1 if d < 0 +// 0 if d == 0 +// +1 if d > 0 +// +func (d Decimal) Sign() int { + if d.value == nil { + return 0 + } + return d.value.Sign() +} + +// IsPositive return +// +// true if d > 0 +// false if d == 0 +// false if d < 0 +func (d Decimal) IsPositive() bool { + return d.Sign() == 1 +} + +// IsNegative return +// +// true if d < 0 +// false if d == 0 +// false if d > 0 +func (d Decimal) IsNegative() bool { + return d.Sign() == -1 +} + +// IsZero return +// +// true if d == 0 +// false if d > 0 +// false if d < 0 +func (d Decimal) IsZero() bool { + return d.Sign() == 0 +} + +// Exponent returns the exponent, or scale component of the decimal. +func (d Decimal) Exponent() int32 { + return d.exp +} + +// Coefficient returns the coefficient of the decimal. It is scaled by 10^Exponent() +func (d Decimal) Coefficient() *big.Int { + // we copy the coefficient so that mutating the result does not mutate the + // Decimal. + return big.NewInt(0).Set(d.value) +} + +// IntPart returns the integer component of the decimal. +func (d Decimal) IntPart() int64 { + scaledD := d.rescale(0) + return scaledD.value.Int64() +} + +// Rat returns a rational number representation of the decimal. +func (d Decimal) Rat() *big.Rat { + d.ensureInitialized() + if d.exp <= 0 { + // NOTE(vadim): must negate after casting to prevent int32 overflow + denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil) + return new(big.Rat).SetFrac(d.value, denom) + } + + mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil) + num := new(big.Int).Mul(d.value, mul) + return new(big.Rat).SetFrac(num, oneInt) +} + +// Float64 returns the nearest float64 value for d and a bool indicating +// whether f represents d exactly. +// For more details, see the documentation for big.Rat.Float64 +func (d Decimal) Float64() (f float64, exact bool) { + return d.Rat().Float64() +} + +// String returns the string representation of the decimal +// with the fixed point. +// +// Example: +// +// d := New(-12345, -3) +// println(d.String()) +// +// Output: +// +// -12.345 +// +func (d Decimal) String() string { + return d.string(true) +} + +// StringFixed returns a rounded fixed-point string with places digits after +// the decimal point. +// +// Example: +// +// NewFromFloat(0).StringFixed(2) // output: "0.00" +// NewFromFloat(0).StringFixed(0) // output: "0" +// NewFromFloat(5.45).StringFixed(0) // output: "5" +// NewFromFloat(5.45).StringFixed(1) // output: "5.5" +// NewFromFloat(5.45).StringFixed(2) // output: "5.45" +// NewFromFloat(5.45).StringFixed(3) // output: "5.450" +// NewFromFloat(545).StringFixed(-1) // output: "550" +// +func (d Decimal) StringFixed(places int32) string { + rounded := d.Round(places) + return rounded.string(false) +} + +// StringFixedBank returns a banker rounded fixed-point string with places digits +// after the decimal point. +// +// Example: +// +// NewFromFloat(0).StringFixed(2) // output: "0.00" +// NewFromFloat(0).StringFixed(0) // output: "0" +// NewFromFloat(5.45).StringFixed(0) // output: "5" +// NewFromFloat(5.45).StringFixed(1) // output: "5.4" +// NewFromFloat(5.45).StringFixed(2) // output: "5.45" +// NewFromFloat(5.45).StringFixed(3) // output: "5.450" +// NewFromFloat(545).StringFixed(-1) // output: "550" +// +func (d Decimal) StringFixedBank(places int32) string { + rounded := d.RoundBank(places) + return rounded.string(false) +} + +// StringFixedCash returns a Swedish/Cash rounded fixed-point string. For +// more details see the documentation at function RoundCash. +func (d Decimal) StringFixedCash(interval uint8) string { + rounded := d.RoundCash(interval) + return rounded.string(false) +} + +// Round rounds the decimal to places decimal places. +// If places < 0, it will round the integer part to the nearest 10^(-places). +// +// Example: +// +// NewFromFloat(5.45).Round(1).String() // output: "5.5" +// NewFromFloat(545).Round(-1).String() // output: "550" +// +func (d Decimal) Round(places int32) Decimal { + // truncate to places + 1 + ret := d.rescale(-places - 1) + + // add sign(d) * 0.5 + if ret.value.Sign() < 0 { + ret.value.Sub(ret.value, fiveInt) + } else { + ret.value.Add(ret.value, fiveInt) + } + + // floor for positive numbers, ceil for negative numbers + _, m := ret.value.DivMod(ret.value, tenInt, new(big.Int)) + ret.exp++ + if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 { + ret.value.Add(ret.value, oneInt) + } + + return ret +} + +// RoundBank rounds the decimal to places decimal places. +// If the final digit to round is equidistant from the nearest two integers the +// rounded value is taken as the even number +// +// If places < 0, it will round the integer part to the nearest 10^(-places). +// +// Examples: +// +// NewFromFloat(5.45).Round(1).String() // output: "5.4" +// NewFromFloat(545).Round(-1).String() // output: "540" +// NewFromFloat(5.46).Round(1).String() // output: "5.5" +// NewFromFloat(546).Round(-1).String() // output: "550" +// NewFromFloat(5.55).Round(1).String() // output: "5.6" +// NewFromFloat(555).Round(-1).String() // output: "560" +// +func (d Decimal) RoundBank(places int32) Decimal { + + round := d.Round(places) + remainder := d.Sub(round).Abs() + + half := New(5, -places-1) + if remainder.Cmp(half) == 0 && round.value.Bit(0) != 0 { + if round.value.Sign() < 0 { + round.value.Add(round.value, oneInt) + } else { + round.value.Sub(round.value, oneInt) + } + } + + return round +} + +// RoundCash aka Cash/Penny/öre rounding rounds decimal to a specific +// interval. The amount payable for a cash transaction is rounded to the nearest +// multiple of the minimum currency unit available. The following intervals are +// available: 5, 10, 15, 25, 50 and 100; any other number throws a panic. +// 5: 5 cent rounding 3.43 => 3.45 +// 10: 10 cent rounding 3.45 => 3.50 (5 gets rounded up) +// 15: 10 cent rounding 3.45 => 3.40 (5 gets rounded down) +// 25: 25 cent rounding 3.41 => 3.50 +// 50: 50 cent rounding 3.75 => 4.00 +// 100: 100 cent rounding 3.50 => 4.00 +// For more details: https://en.wikipedia.org/wiki/Cash_rounding +func (d Decimal) RoundCash(interval uint8) Decimal { + var iVal *big.Int + switch interval { + case 5: + iVal = twentyInt + case 10: + iVal = tenInt + case 15: + if d.exp < 0 { + // TODO: optimize and reduce allocations + orgExp := d.exp + dOne := New(10^-int64(orgExp), orgExp) + d2 := d + d2.exp = 0 + if d2.Mod(fiveDec).Equal(Zero) { + d2.exp = orgExp + d2 = d2.Sub(dOne) + d = d2 + } + } + iVal = tenInt + case 25: + iVal = fourInt + case 50: + iVal = twoInt + case 100: + iVal = oneInt + default: + panic(fmt.Sprintf("Decimal does not support this Cash rounding interval `%d`. Supported: 5, 10, 15, 25, 50, 100", interval)) + } + dVal := Decimal{ + value: iVal, + } + // TODO: optimize those calculations to reduce the high allocations (~29 allocs). + return d.Mul(dVal).Round(0).Div(dVal).Truncate(2) +} + +// Floor returns the nearest integer value less than or equal to d. +func (d Decimal) Floor() Decimal { + d.ensureInitialized() + + if d.exp >= 0 { + return d + } + + exp := big.NewInt(10) + + // NOTE(vadim): must negate after casting to prevent int32 overflow + exp.Exp(exp, big.NewInt(-int64(d.exp)), nil) + + z := new(big.Int).Div(d.value, exp) + return Decimal{value: z, exp: 0} +} + +// Ceil returns the nearest integer value greater than or equal to d. +func (d Decimal) Ceil() Decimal { + d.ensureInitialized() + + if d.exp >= 0 { + return d + } + + exp := big.NewInt(10) + + // NOTE(vadim): must negate after casting to prevent int32 overflow + exp.Exp(exp, big.NewInt(-int64(d.exp)), nil) + + z, m := new(big.Int).DivMod(d.value, exp, new(big.Int)) + if m.Cmp(zeroInt) != 0 { + z.Add(z, oneInt) + } + return Decimal{value: z, exp: 0} +} + +// Truncate truncates off digits from the number, without rounding. +// +// NOTE: precision is the last digit that will not be truncated (must be >= 0). +// +// Example: +// +// decimal.NewFromString("123.456").Truncate(2).String() // "123.45" +// +func (d Decimal) Truncate(precision int32) Decimal { + d.ensureInitialized() + if precision >= 0 && -precision > d.exp { + return d.rescale(-precision) + } + return d +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { + if string(decimalBytes) == "null" { + return nil + } + + str, err := unquoteIfQuoted(decimalBytes) + if err != nil { + return fmt.Errorf("Error decoding string '%s': %s", decimalBytes, err) + } + + decimal, err := NewFromString(str) + *d = decimal + if err != nil { + return fmt.Errorf("Error decoding string '%s': %s", str, err) + } + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (d Decimal) MarshalJSON() ([]byte, error) { + var str string + if MarshalJSONWithoutQuotes { + str = d.String() + } else { + str = "\"" + d.String() + "\"" + } + return []byte(str), nil +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation +// is already used when encoding to text, this method stores that string as []byte +func (d *Decimal) UnmarshalBinary(data []byte) error { + // Extract the exponent + d.exp = int32(binary.BigEndian.Uint32(data[:4])) + + // Extract the value + d.value = new(big.Int) + return d.value.GobDecode(data[4:]) +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (d Decimal) MarshalBinary() (data []byte, err error) { + // Write the exponent first since it's a fixed size + v1 := make([]byte, 4) + binary.BigEndian.PutUint32(v1, uint32(d.exp)) + + // Add the value + var v2 []byte + if v2, err = d.value.GobEncode(); err != nil { + return + } + + // Return the byte array + data = append(v1, v2...) + return +} + +// Scan implements the sql.Scanner interface for database deserialization. +func (d *Decimal) Scan(value interface{}) error { + // first try to see if the data is stored in database as a Numeric datatype + switch v := value.(type) { + + case float32: + *d = NewFromFloat(float64(v)) + return nil + + case float64: + // numeric in sqlite3 sends us float64 + *d = NewFromFloat(v) + return nil + + case int64: + // at least in sqlite3 when the value is 0 in db, the data is sent + // to us as an int64 instead of a float64 ... + *d = New(v, 0) + return nil + + default: + // default is trying to interpret value stored as string + str, err := unquoteIfQuoted(v) + if err != nil { + return err + } + *d, err = NewFromString(str) + return err + } +} + +// Value implements the driver.Valuer interface for database serialization. +func (d Decimal) Value() (driver.Value, error) { + return d.String(), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for XML +// deserialization. +func (d *Decimal) UnmarshalText(text []byte) error { + str := string(text) + + dec, err := NewFromString(str) + *d = dec + if err != nil { + return fmt.Errorf("Error decoding string '%s': %s", str, err) + } + + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface for XML +// serialization. +func (d Decimal) MarshalText() (text []byte, err error) { + return []byte(d.String()), nil +} + +// GobEncode implements the gob.GobEncoder interface for gob serialization. +func (d Decimal) GobEncode() ([]byte, error) { + return d.MarshalBinary() +} + +// GobDecode implements the gob.GobDecoder interface for gob serialization. +func (d *Decimal) GobDecode(data []byte) error { + return d.UnmarshalBinary(data) +} + +// StringScaled first scales the decimal then calls .String() on it. +// NOTE: buggy, unintuitive, and DEPRECATED! Use StringFixed instead. +func (d Decimal) StringScaled(exp int32) string { + return d.rescale(exp).String() +} + +func (d Decimal) string(trimTrailingZeros bool) string { + if d.exp >= 0 { + return d.rescale(0).value.String() + } + + abs := new(big.Int).Abs(d.value) + str := abs.String() + + var intPart, fractionalPart string + + // NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN + // and you are on a 32-bit machine. Won't fix this super-edge case. + dExpInt := int(d.exp) + if len(str) > -dExpInt { + intPart = str[:len(str)+dExpInt] + fractionalPart = str[len(str)+dExpInt:] + } else { + intPart = "0" + + num0s := -dExpInt - len(str) + fractionalPart = strings.Repeat("0", num0s) + str + } + + if trimTrailingZeros { + i := len(fractionalPart) - 1 + for ; i >= 0; i-- { + if fractionalPart[i] != '0' { + break + } + } + fractionalPart = fractionalPart[:i+1] + } + + number := intPart + if len(fractionalPart) > 0 { + number += "." + fractionalPart + } + + if d.value.Sign() < 0 { + return "-" + number + } + + return number +} + +func (d *Decimal) ensureInitialized() { + if d.value == nil { + d.value = new(big.Int) + } +} + +// Min returns the smallest Decimal that was passed in the arguments. +// +// To call this function with an array, you must do: +// +// Min(arr[0], arr[1:]...) +// +// This makes it harder to accidentally call Min with 0 arguments. +func Min(first Decimal, rest ...Decimal) Decimal { + ans := first + for _, item := range rest { + if item.Cmp(ans) < 0 { + ans = item + } + } + return ans +} + +// Max returns the largest Decimal that was passed in the arguments. +// +// To call this function with an array, you must do: +// +// Max(arr[0], arr[1:]...) +// +// This makes it harder to accidentally call Max with 0 arguments. +func Max(first Decimal, rest ...Decimal) Decimal { + ans := first + for _, item := range rest { + if item.Cmp(ans) > 0 { + ans = item + } + } + return ans +} + +// Sum returns the combined total of the provided first and rest Decimals +func Sum(first Decimal, rest ...Decimal) Decimal { + total := first + for _, item := range rest { + total = total.Add(item) + } + + return total +} + +// Avg returns the average value of the provided first and rest Decimals +func Avg(first Decimal, rest ...Decimal) Decimal { + count := New(int64(len(rest)+1), 0) + sum := Sum(first, rest...) + return sum.Div(count) +} + +func min(x, y int32) int32 { + if x >= y { + return y + } + return x +} + +func unquoteIfQuoted(value interface{}) (string, error) { + var bytes []byte + + switch v := value.(type) { + case string: + bytes = []byte(v) + case []byte: + bytes = v + default: + return "", fmt.Errorf("Could not convert value '%+v' to byte array of type '%T'", + value, value) + } + + // If the amount is quoted, strip the quotes + if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' { + bytes = bytes[1 : len(bytes)-1] + } + return string(bytes), nil +} + +// NullDecimal represents a nullable decimal with compatibility for +// scanning null values from the database. +type NullDecimal struct { + Decimal Decimal + Valid bool +} + +// Scan implements the sql.Scanner interface for database deserialization. +func (d *NullDecimal) Scan(value interface{}) error { + if value == nil { + d.Valid = false + return nil + } + d.Valid = true + return d.Decimal.Scan(value) +} + +// Value implements the driver.Valuer interface for database serialization. +func (d NullDecimal) Value() (driver.Value, error) { + if !d.Valid { + return nil, nil + } + return d.Decimal.Value() +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (d *NullDecimal) UnmarshalJSON(decimalBytes []byte) error { + if string(decimalBytes) == "null" { + d.Valid = false + return nil + } + d.Valid = true + return d.Decimal.UnmarshalJSON(decimalBytes) +} + +// MarshalJSON implements the json.Marshaler interface. +func (d NullDecimal) MarshalJSON() ([]byte, error) { + if !d.Valid { + return []byte("null"), nil + } + return d.Decimal.MarshalJSON() +} + +// Trig functions + +// Atan returns the arctangent, in radians, of x. +func (x Decimal) Atan() Decimal { + if x.Equal(NewFromFloat(0.0)) { + return x + } + if x.GreaterThan(NewFromFloat(0.0)) { + return x.satan() + } + return x.Neg().satan().Neg() +} + +func (d Decimal) xatan() Decimal { + P0 := NewFromFloat(-8.750608600031904122785e-01) + P1 := NewFromFloat(-1.615753718733365076637e+01) + P2 := NewFromFloat(-7.500855792314704667340e+01) + P3 := NewFromFloat(-1.228866684490136173410e+02) + P4 := NewFromFloat(-6.485021904942025371773e+01) + Q0 := NewFromFloat(2.485846490142306297962e+01) + Q1 := NewFromFloat(1.650270098316988542046e+02) + Q2 := NewFromFloat(4.328810604912902668951e+02) + Q3 := NewFromFloat(4.853903996359136964868e+02) + Q4 := NewFromFloat(1.945506571482613964425e+02) + z := d.Mul(d) + b1 := P0.Mul(z).Add(P1).Mul(z).Add(P2).Mul(z).Add(P3).Mul(z).Add(P4).Mul(z) + b2 := z.Add(Q0).Mul(z).Add(Q1).Mul(z).Add(Q2).Mul(z).Add(Q3).Mul(z).Add(Q4) + z = b1.Div(b2) + z = d.Mul(z).Add(d) + return z +} + +// satan reduces its argument (known to be positive) +// to the range [0, 0.66] and calls xatan. +func (d Decimal) satan() Decimal { + Morebits := NewFromFloat(6.123233995736765886130e-17) // pi/2 = PIO2 + Morebits + Tan3pio8 := NewFromFloat(2.41421356237309504880) // tan(3*pi/8) + pi := NewFromFloat(3.14159265358979323846264338327950288419716939937510582097494459) + + if d.LessThanOrEqual(NewFromFloat(0.66)) { + return d.xatan() + } + if d.GreaterThan(Tan3pio8) { + return pi.Div(NewFromFloat(2.0)).Sub(NewFromFloat(1.0).Div(d).xatan()).Add(Morebits) + } + return pi.Div(NewFromFloat(4.0)).Add((d.Sub(NewFromFloat(1.0)).Div(d.Add(NewFromFloat(1.0)))).xatan()).Add(NewFromFloat(0.5).Mul(Morebits)) +} + +// sin coefficients + var _sin = [...]Decimal{ + NewFromFloat(1.58962301576546568060E-10), // 0x3de5d8fd1fd19ccd + NewFromFloat(-2.50507477628578072866E-8), // 0xbe5ae5e5a9291f5d + NewFromFloat(2.75573136213857245213E-6), // 0x3ec71de3567d48a1 + NewFromFloat(-1.98412698295895385996E-4), // 0xbf2a01a019bfdf03 + NewFromFloat(8.33333333332211858878E-3), // 0x3f8111111110f7d0 + NewFromFloat(-1.66666666666666307295E-1), // 0xbfc5555555555548 + } + +// Sin returns the sine of the radian argument x. + func (d Decimal) Sin() Decimal { + PI4A := NewFromFloat(7.85398125648498535156E-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668E-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645E-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + if d.Equal(NewFromFloat(0.0)) { + return d + } + // make argument positive but save the sign + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + sign = true + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + j &= 7 // octant modulo 2Pi radians (360 degrees) + // reflect in x axis + if j > 3 { + sign = !sign + j -= 4 + } + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if j == 1 || j == 2 { + w := zz.Mul(zz).Mul(_cos[0].Mul(zz).Add(_cos[1]).Mul(zz).Add(_cos[2]).Mul(zz).Add(_cos[3]).Mul(zz).Add(_cos[4]).Mul(zz).Add(_cos[5])) + y = NewFromFloat(1.0).Sub(NewFromFloat(0.5).Mul(zz)).Add(w) + } else { + y = z.Add(z.Mul(zz).Mul(_sin[0].Mul(zz).Add(_sin[1]).Mul(zz).Add(_sin[2]).Mul(zz).Add(_sin[3]).Mul(zz).Add(_sin[4]).Mul(zz).Add(_sin[5]))) + } + if sign { + y = y.Neg() + } + return y + } + + // cos coefficients + var _cos = [...]Decimal{ + NewFromFloat(-1.13585365213876817300E-11), // 0xbda8fa49a0861a9b + NewFromFloat(2.08757008419747316778E-9), // 0x3e21ee9d7b4e3f05 + NewFromFloat(-2.75573141792967388112E-7), // 0xbe927e4f7eac4bc6 + NewFromFloat(2.48015872888517045348E-5), // 0x3efa01a019c844f5 + NewFromFloat(-1.38888888888730564116E-3), // 0xbf56c16c16c14f91 + NewFromFloat(4.16666666666665929218E-2), // 0x3fa555555555554b + } + + // Cos returns the cosine of the radian argument x. + func (d Decimal) Cos() Decimal { + + PI4A := NewFromFloat(7.85398125648498535156E-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668E-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645E-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + // make argument positive + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + j &= 7 // octant modulo 2Pi radians (360 degrees) + // reflect in x axis + if j > 3 { + sign = !sign + j -= 4 + } + if j > 1 { + sign = !sign + } + + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if j == 1 || j == 2 { + y = z.Add(z.Mul(zz).Mul(_sin[0].Mul(zz).Add(_sin[1]).Mul(zz).Add(_sin[2]).Mul(zz).Add(_sin[3]).Mul(zz).Add(_sin[4]).Mul(zz).Add(_sin[5]))) + } else { + w := zz.Mul(zz).Mul(_cos[0].Mul(zz).Add(_cos[1]).Mul(zz).Add(_cos[2]).Mul(zz).Add(_cos[3]).Mul(zz).Add(_cos[4]).Mul(zz).Add(_cos[5])) + y = NewFromFloat(1.0).Sub(NewFromFloat(0.5).Mul(zz)).Add(w) + } + if sign { + y = y.Neg() + } + return y + } + + var _tanP = [...]Decimal{ + NewFromFloat(-1.30936939181383777646E+4), // 0xc0c992d8d24f3f38 + NewFromFloat(1.15351664838587416140E+6), // 0x413199eca5fc9ddd + NewFromFloat(-1.79565251976484877988E+7), // 0xc1711fead3299176 + } + var _tanQ = [...]Decimal{ + NewFromFloat(1.00000000000000000000E+0), + NewFromFloat(1.36812963470692954678E+4), //0x40cab8a5eeb36572 + NewFromFloat(-1.32089234440210967447E+6), //0xc13427bc582abc96 + NewFromFloat(2.50083801823357915839E+7), //0x4177d98fc2ead8ef + NewFromFloat(-5.38695755929454629881E+7), //0xc189afe03cbe5a31 + } + + // Tan returns the tangent of the radian argument x. + func (d Decimal) Tan() Decimal { + + PI4A := NewFromFloat(7.85398125648498535156E-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668E-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645E-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + if d.Equal(NewFromFloat(0.0)) { + return d + } + + // make argument positive but save the sign + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + sign = true + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if zz.GreaterThan(NewFromFloat(1e-14)) { + w := zz.Mul(_tanP[0].Mul(zz).Add(_tanP[1]).Mul(zz).Add(_tanP[2])) + x := zz.Add(_tanQ[1]).Mul(zz).Add(_tanQ[2]).Mul(zz).Add(_tanQ[3]).Mul(zz).Add(_tanQ[4]) + y = z.Add(z.Mul(w.Div(x))) + } else { + y = z + } + if j&2 == 2 { + y = NewFromFloat(-1.0).Div(y) + } + if sign { + y = y.Neg() + } + return y + } diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/rounding.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/rounding.go new file mode 100644 index 0000000..fdd74ea --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/shopspring/decimal/rounding.go @@ -0,0 +1,118 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Multiprecision decimal numbers. +// For floating-point formatting only; not general purpose. +// Only operations are assign and (binary) left/right shift. +// Can do binary floating point in multiprecision decimal precisely +// because 2 divides 10; cannot do decimal floating point +// in multiprecision binary precisely. +package decimal + +type floatInfo struct { + mantbits uint + expbits uint + bias int +} + +var float32info = floatInfo{23, 8, -127} +var float64info = floatInfo{52, 11, -1023} + +// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits +// that will let the original floating point value be precisely reconstructed. +func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) { + // If mantissa is zero, the number is zero; stop now. + if mant == 0 { + d.nd = 0 + return + } + + // Compute upper and lower such that any decimal number + // between upper and lower (possibly inclusive) + // will round to the original floating point number. + + // We may see at once that the number is already shortest. + // + // Suppose d is not denormal, so that 2^exp <= d < 10^dp. + // The closest shorter number is at least 10^(dp-nd) away. + // The lower/upper bounds computed below are at distance + // at most 2^(exp-mantbits). + // + // So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits), + // or equivalently log2(10)*(dp-nd) > exp-mantbits. + // It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32). + minexp := flt.bias + 1 // minimum possible exponent + if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) { + // The number is already shortest. + return + } + + // d = mant << (exp - mantbits) + // Next highest floating point number is mant+1 << exp-mantbits. + // Our upper bound is halfway between, mant*2+1 << exp-mantbits-1. + upper := new(decimal) + upper.Assign(mant*2 + 1) + upper.Shift(exp - int(flt.mantbits) - 1) + + // d = mant << (exp - mantbits) + // Next lowest floating point number is mant-1 << exp-mantbits, + // unless mant-1 drops the significant bit and exp is not the minimum exp, + // in which case the next lowest is mant*2-1 << exp-mantbits-1. + // Either way, call it mantlo << explo-mantbits. + // Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1. + var mantlo uint64 + var explo int + if mant > 1< 0 { + h.fd.Close() + + for i := h.backupCount - 1; i > 0; i-- { + sfn := fmt.Sprintf("%s.%d", h.fileName, i) + dfn := fmt.Sprintf("%s.%d", h.fileName, i+1) + + os.Rename(sfn, dfn) + } + + dfn := fmt.Sprintf("%s.1", h.fileName) + os.Rename(h.fileName, dfn) + + h.fd, _ = os.OpenFile(h.fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + h.curBytes = 0 + f, err := h.fd.Stat() + if err != nil { + return + } + h.curBytes = int(f.Size()) + } +} + +// TimeRotatingFileHandler writes log to a file, +// it will backup current and open a new one, with a period time you sepecified. +// +// refer: http://docs.python.org/2/library/logging.handlers.html. +// same like python TimedRotatingFileHandler. +type TimeRotatingFileHandler struct { + fd *os.File + + baseName string + interval int64 + suffix string + rolloverAt int64 +} + +// TimeRotating way +const ( + WhenSecond = iota + WhenMinute + WhenHour + WhenDay +) + +// NewTimeRotatingFileHandler creates a TimeRotatingFileHandler +func NewTimeRotatingFileHandler(baseName string, when int8, interval int) (*TimeRotatingFileHandler, error) { + dir := path.Dir(baseName) + os.MkdirAll(dir, 0777) + + h := new(TimeRotatingFileHandler) + + h.baseName = baseName + + switch when { + case WhenSecond: + h.interval = 1 + h.suffix = "2006-01-02_15-04-05" + case WhenMinute: + h.interval = 60 + h.suffix = "2006-01-02_15-04" + case WhenHour: + h.interval = 3600 + h.suffix = "2006-01-02_15" + case WhenDay: + h.interval = 3600 * 24 + h.suffix = "2006-01-02" + default: + return nil, fmt.Errorf("invalid when_rotate: %d", when) + } + + h.interval = h.interval * int64(interval) + + var err error + h.fd, err = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + return nil, err + } + + fInfo, _ := h.fd.Stat() + h.rolloverAt = fInfo.ModTime().Unix() + h.interval + + return h, nil +} + +func (h *TimeRotatingFileHandler) doRollover() { + //refer http://hg.python.org/cpython/file/2.7/Lib/logging/handlers.py + now := time.Now() + + if h.rolloverAt <= now.Unix() { + fName := h.baseName + now.Format(h.suffix) + h.fd.Close() + e := os.Rename(h.baseName, fName) + if e != nil { + panic(e) + } + + h.fd, _ = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + + h.rolloverAt = time.Now().Unix() + h.interval + } +} + +// Write implements Handler interface +func (h *TimeRotatingFileHandler) Write(b []byte) (n int, err error) { + h.doRollover() + return h.fd.Write(b) +} + +// Close implements Handler interface +func (h *TimeRotatingFileHandler) Close() error { + return h.fd.Close() +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/handler.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/handler.go new file mode 100644 index 0000000..5460f06 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/handler.go @@ -0,0 +1,54 @@ +package log + +import ( + "io" +) + +//Handler writes logs to somewhere +type Handler interface { + Write(p []byte) (n int, err error) + Close() error +} + +// StreamHandler writes logs to a specified io Writer, maybe stdout, stderr, etc... +type StreamHandler struct { + w io.Writer +} + +// NewStreamHandler creates a StreamHandler +func NewStreamHandler(w io.Writer) (*StreamHandler, error) { + h := new(StreamHandler) + + h.w = w + + return h, nil +} + +// Write implements Handler interface +func (h *StreamHandler) Write(b []byte) (n int, err error) { + return h.w.Write(b) +} + +// Close implements Handler interface +func (h *StreamHandler) Close() error { + return nil +} + +// NullHandler does nothing, it discards anything. +type NullHandler struct { +} + +// NewNullHandler creates a NullHandler +func NewNullHandler() (*NullHandler, error) { + return new(NullHandler), nil +} + +// // Write implements Handler interface +func (h *NullHandler) Write(b []byte) (n int, err error) { + return len(b), nil +} + +// Close implements Handler interface +func (h *NullHandler) Close() { + +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/log.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/log.go new file mode 100644 index 0000000..956186d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/log.go @@ -0,0 +1,137 @@ +package log + +import ( + "fmt" + "os" +) + +var logger = NewDefault(newStdHandler()) + +// SetDefaultLogger changes the global logger +func SetDefaultLogger(l *Logger) { + logger = l +} + +// SetLevel changes the logger level +func SetLevel(level Level) { + logger.SetLevel(level) +} + +// SetLevelByName changes the logger level by name +func SetLevelByName(name string) { + logger.SetLevelByName(name) +} + +// Fatal records the log with fatal level and exits +func Fatal(args ...interface{}) { + logger.Output(2, LevelFatal, fmt.Sprint(args...)) + os.Exit(1) +} + +// Fatalf records the log with fatal level and exits +func Fatalf(format string, args ...interface{}) { + logger.Output(2, LevelFatal, fmt.Sprintf(format, args...)) + os.Exit(1) +} + +// Fatalln records the log with fatal level and exits +func Fatalln(args ...interface{}) { + logger.Output(2, LevelFatal, fmt.Sprintln(args...)) + os.Exit(1) +} + +// Panic records the log with fatal level and panics +func Panic(args ...interface{}) { + msg := fmt.Sprint(args...) + logger.Output(2, LevelError, msg) + panic(msg) +} + +// Panicf records the log with fatal level and panics +func Panicf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.Output(2, LevelError, msg) + panic(msg) +} + +// Panicln records the log with fatal level and panics +func Panicln(args ...interface{}) { + msg := fmt.Sprintln(args...) + logger.Output(2, LevelError, msg) + panic(msg) +} + +// Print records the log with trace level +func Print(args ...interface{}) { + logger.Output(2, LevelTrace, fmt.Sprint(args...)) +} + +// Printf records the log with trace level +func Printf(format string, args ...interface{}) { + logger.Output(2, LevelTrace, fmt.Sprintf(format, args...)) +} + +// Println records the log with trace level +func Println(args ...interface{}) { + logger.Output(2, LevelTrace, fmt.Sprintln(args...)) +} + +// Debug records the log with debug level +func Debug(args ...interface{}) { + logger.Output(2, LevelDebug, fmt.Sprint(args...)) +} + +// Debugf records the log with debug level +func Debugf(format string, args ...interface{}) { + logger.Output(2, LevelDebug, fmt.Sprintf(format, args...)) +} + +// Debugln records the log with debug level +func Debugln(args ...interface{}) { + logger.Output(2, LevelDebug, fmt.Sprintln(args...)) +} + +// Error records the log with error level +func Error(args ...interface{}) { + logger.Output(2, LevelError, fmt.Sprint(args...)) +} + +// Errorf records the log with error level +func Errorf(format string, args ...interface{}) { + logger.Output(2, LevelError, fmt.Sprintf(format, args...)) +} + +// Errorln records the log with error level +func Errorln(args ...interface{}) { + logger.Output(2, LevelError, fmt.Sprintln(args...)) +} + +// Info records the log with info level +func Info(args ...interface{}) { + logger.Output(2, LevelInfo, fmt.Sprint(args...)) +} + +// Infof records the log with info level +func Infof(format string, args ...interface{}) { + logger.Output(2, LevelInfo, fmt.Sprintf(format, args...)) +} + +// Infoln records the log with info level +func Infoln(args ...interface{}) { + logger.Output(2, LevelInfo, fmt.Sprintln(args...)) +} + +// Warn records the log with warn level +func Warn(args ...interface{}) { + logger.Output(2, LevelWarn, fmt.Sprint(args...)) +} + +// Warnf records the log with warn level +func Warnf(format string, args ...interface{}) { + logger.Output(2, LevelWarn, fmt.Sprintf(format, args...)) +} + +// Warnln records the log with warn level +func Warnln(args ...interface{}) { + logger.Output(2, LevelWarn, fmt.Sprintln(args...)) +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/logger.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/logger.go new file mode 100644 index 0000000..b2f7ed2 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/log/logger.go @@ -0,0 +1,340 @@ +package log + +import ( + "fmt" + "os" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/siddontang/go-log/loggers" +) + +const ( + timeFormat = "2006/01/02 15:04:05" + maxBufPoolSize = 16 +) + +// Logger flag +const ( + Ltime = 1 << iota // time format "2006/01/02 15:04:05" + Lfile // file.go:123 + Llevel // [Trace|Debug|Info...] +) + +// Level type +type Level int + +// Log level, from low to high, more high means more serious +const ( + LevelTrace Level = iota + LevelDebug + LevelInfo + LevelWarn + LevelError + LevelFatal +) + +// String returns level String +func (l Level) String() string { + switch l { + case LevelTrace: + return "trace" + case LevelDebug: + return "debug" + case LevelInfo: + return "info" + case LevelWarn: + return "warn" + case LevelError: + return "error" + case LevelFatal: + return "fatal" + } + // return default info + return "info" +} + +// Logger is the logger to record log +type Logger struct { + // TODO: support logger.Contextual + loggers.Advanced + + sync.Mutex + + level Level + flag int + + handler Handler + + quit chan struct{} + msg chan []byte + + bufs [][]byte +} + +// New creates a logger with specified handler and flag +func New(handler Handler, flag int) *Logger { + var l = new(Logger) + + l.level = LevelInfo + l.handler = handler + + l.flag = flag + + l.quit = make(chan struct{}) + + l.msg = make(chan []byte, 1024) + + l.bufs = make([][]byte, 0, 16) + + go l.run() + + return l +} + +// NewDefault creates default logger with specified handler and flag: Ltime|Lfile|Llevel +func NewDefault(handler Handler) *Logger { + return New(handler, Ltime|Lfile|Llevel) +} + +func newStdHandler() *StreamHandler { + h, _ := NewStreamHandler(os.Stdout) + return h +} + +func (l *Logger) run() { + for { + select { + case msg := <-l.msg: + l.handler.Write(msg) + l.putBuf(msg) + case <-l.quit: + l.handler.Close() + } + } +} + +func (l *Logger) popBuf() []byte { + l.Lock() + var buf []byte + if len(l.bufs) == 0 { + buf = make([]byte, 0, 1024) + } else { + buf = l.bufs[len(l.bufs)-1] + l.bufs = l.bufs[0 : len(l.bufs)-1] + } + l.Unlock() + + return buf +} + +func (l *Logger) putBuf(buf []byte) { + l.Lock() + if len(l.bufs) < maxBufPoolSize { + buf = buf[0:0] + l.bufs = append(l.bufs, buf) + } + l.Unlock() +} + +// Close closes the logger +func (l *Logger) Close() { + if l.quit == nil { + return + } + + close(l.quit) + l.quit = nil +} + +// SetLevel sets log level, any log level less than it will not log +func (l *Logger) SetLevel(level Level) { + l.level = level +} + +// SetLevelByName sets log level by name +func (l *Logger) SetLevelByName(name string) { + level := LevelInfo + switch strings.ToLower(name) { + case "trace": + level = LevelTrace + case "debug": + level = LevelDebug + case "warn", "warning": + level = LevelWarn + case "error": + level = LevelError + case "fatal": + level = LevelFatal + default: + level = LevelInfo + } + + l.SetLevel(level) +} + +// Output records the log with special callstack depth and log level. +func (l *Logger) Output(callDepth int, level Level, msg string) { + if l.level > level { + return + } + + buf := l.popBuf() + + if l.flag&Ltime > 0 { + now := time.Now().Format(timeFormat) + buf = append(buf, '[') + buf = append(buf, now...) + buf = append(buf, "] "...) + } + + if l.flag&Llevel > 0 { + buf = append(buf, '[') + buf = append(buf, level.String()...) + buf = append(buf, "] "...) + } + + if l.flag&Lfile > 0 { + _, file, line, ok := runtime.Caller(callDepth) + if !ok { + file = "???" + line = 0 + } else { + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + file = file[i+1:] + break + } + } + } + + buf = append(buf, file...) + buf = append(buf, ':') + + buf = strconv.AppendInt(buf, int64(line), 10) + buf = append(buf, ' ') + } + + buf = append(buf, msg...) + if len(msg) == 0 || msg[len(msg)-1] != '\n' { + buf = append(buf, '\n') + } + l.msg <- buf +} + +// Fatal records the log with fatal level and exits +func (l *Logger) Fatal(args ...interface{}) { + l.Output(2, LevelFatal, fmt.Sprint(args...)) + os.Exit(1) +} + +// Fatalf records the log with fatal level and exits +func (l *Logger) Fatalf(format string, args ...interface{}) { + l.Output(2, LevelFatal, fmt.Sprintf(format, args...)) + os.Exit(1) +} + +// Fatalln records the log with fatal level and exits +func (l *Logger) Fatalln(args ...interface{}) { + l.Output(2, LevelFatal, fmt.Sprintln(args...)) + os.Exit(1) +} + +// Panic records the log with fatal level and panics +func (l *Logger) Panic(args ...interface{}) { + msg := fmt.Sprint(args...) + l.Output(2, LevelError, msg) + panic(msg) +} + +// Panicf records the log with fatal level and panics +func (l *Logger) Panicf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + l.Output(2, LevelError, msg) + panic(msg) +} + +// Panicln records the log with fatal level and panics +func (l *Logger) Panicln(args ...interface{}) { + msg := fmt.Sprintln(args...) + l.Output(2, LevelError, msg) + panic(msg) +} + +// Print records the log with trace level +func (l *Logger) Print(args ...interface{}) { + l.Output(2, LevelTrace, fmt.Sprint(args...)) +} + +// Printf records the log with trace level +func (l *Logger) Printf(format string, args ...interface{}) { + l.Output(2, LevelTrace, fmt.Sprintf(format, args...)) +} + +// Println records the log with trace level +func (l *Logger) Println(args ...interface{}) { + l.Output(2, LevelTrace, fmt.Sprintln(args...)) +} + +// Debug records the log with debug level +func (l *Logger) Debug(args ...interface{}) { + l.Output(2, LevelDebug, fmt.Sprint(args...)) +} + +// Debugf records the log with debug level +func (l *Logger) Debugf(format string, args ...interface{}) { + l.Output(2, LevelDebug, fmt.Sprintf(format, args...)) +} + +// Debugln records the log with debug level +func (l *Logger) Debugln(args ...interface{}) { + l.Output(2, LevelDebug, fmt.Sprintln(args...)) +} + +// Error records the log with error level +func (l *Logger) Error(args ...interface{}) { + l.Output(2, LevelError, fmt.Sprint(args...)) +} + +// Errorf records the log with error level +func (l *Logger) Errorf(format string, args ...interface{}) { + l.Output(2, LevelError, fmt.Sprintf(format, args...)) +} + +// Errorln records the log with error level +func (l *Logger) Errorln(args ...interface{}) { + l.Output(2, LevelError, fmt.Sprintln(args...)) +} + +// Info records the log with info level +func (l *Logger) Info(args ...interface{}) { + l.Output(2, LevelInfo, fmt.Sprint(args...)) +} + +// Infof records the log with info level +func (l *Logger) Infof(format string, args ...interface{}) { + l.Output(2, LevelInfo, fmt.Sprintf(format, args...)) +} + +// Infoln records the log with info level +func (l *Logger) Infoln(args ...interface{}) { + l.Output(2, LevelInfo, fmt.Sprintln(args...)) +} + +// Warn records the log with warn level +func (l *Logger) Warn(args ...interface{}) { + l.Output(2, LevelWarn, fmt.Sprint(args...)) +} + +// Warnf records the log with warn level +func (l *Logger) Warnf(format string, args ...interface{}) { + l.Output(2, LevelWarn, fmt.Sprintf(format, args...)) +} + +// Warnln records the log with warn level +func (l *Logger) Warnln(args ...interface{}) { + l.Output(2, LevelWarn, fmt.Sprintln(args...)) +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/loggers/loggers.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/loggers/loggers.go new file mode 100644 index 0000000..2723b24 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go-log/loggers/loggers.go @@ -0,0 +1,68 @@ +// MIT License + +// Copyright (c) 2017 Birkir A. Barkarson + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package loggers + +// Standard is the interface used by Go's standard library's log package. +type Standard interface { + Fatal(args ...interface{}) + Fatalf(format string, args ...interface{}) + Fatalln(args ...interface{}) + + Panic(args ...interface{}) + Panicf(format string, args ...interface{}) + Panicln(args ...interface{}) + + Print(args ...interface{}) + Printf(format string, args ...interface{}) + Println(args ...interface{}) +} + +// Advanced is an interface with commonly used log level methods. +type Advanced interface { + Standard + + Debug(args ...interface{}) + Debugf(format string, args ...interface{}) + Debugln(args ...interface{}) + + Error(args ...interface{}) + Errorf(format string, args ...interface{}) + Errorln(args ...interface{}) + + Info(args ...interface{}) + Infof(format string, args ...interface{}) + Infoln(args ...interface{}) + + Warn(args ...interface{}) + Warnf(format string, args ...interface{}) + Warnln(args ...interface{}) +} + +// Contextual is an interface that allows context addition to a log statement before +// calling the final print (message/level) method. +type Contextual interface { + Advanced + + WithField(key string, value interface{}) Advanced + WithFields(fields ...interface{}) Advanced +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/LICENSE similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/LICENSE rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/LICENSE diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/bson/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/bson/LICENSE new file mode 100644 index 0000000..8903260 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/bson/LICENSE @@ -0,0 +1,25 @@ +BSON library for Go + +Copyright (c) 2010-2012 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/filelock/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/filelock/LICENSE new file mode 100644 index 0000000..fec05ce --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/filelock/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2011 The LevelDB-Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/hack/hack.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/hack/hack.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/hack/hack.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/hack/hack.go diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/snappy/LICENSE similarity index 95% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/LICENSE rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/snappy/LICENSE index 6a66aea..6050c10 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/LICENSE +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/snappy/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. +Copyright (c) 2011 The Snappy-Go Authors. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/atomic.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/sync2/atomic.go similarity index 99% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/atomic.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/sync2/atomic.go index 424a974..382fc20 100644 --- a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/atomic.go +++ b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/sync2/atomic.go @@ -121,7 +121,7 @@ func (s *AtomicString) Get() string { return str } -func (s *AtomicString) CompareAndSwap(oldval, newval string) (swqpped bool) { +func (s *AtomicString) CompareAndSwap(oldval, newval string) (swapped bool) { s.mu.Lock() defer s.mu.Unlock() if s.str == oldval { diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/semaphore.go b/vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/sync2/semaphore.go similarity index 100% rename from vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/semaphore.go rename to vendor/github.com/siddontang/go-mysql/vendor/github.com/siddontang/go/sync2/semaphore.go diff --git a/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/LICENSE b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql.go b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql.go new file mode 100644 index 0000000..7b27e6b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql.go @@ -0,0 +1,62 @@ +// Copyright 2013 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +/* +Package cloudsql exposes access to Google Cloud SQL databases. + +This package does not work in App Engine "flexible environment". + +This package is intended for MySQL drivers to make App Engine-specific +connections. Applications should use this package through database/sql: +Select a pure Go MySQL driver that supports this package, and use sql.Open +with protocol "cloudsql" and an address of the Cloud SQL instance. + +A Go MySQL driver that has been tested to work well with Cloud SQL +is the go-sql-driver: + import "database/sql" + import _ "github.com/go-sql-driver/mysql" + + db, err := sql.Open("mysql", "user@cloudsql(project-id:instance-name)/dbname") + + +Another driver that works well with Cloud SQL is the mymysql driver: + import "database/sql" + import _ "github.com/ziutek/mymysql/godrv" + + db, err := sql.Open("mymysql", "cloudsql:instance-name*dbname/user/password") + + +Using either of these drivers, you can perform a standard SQL query. +This example assumes there is a table named 'users' with +columns 'first_name' and 'last_name': + + rows, err := db.Query("SELECT first_name, last_name FROM users") + if err != nil { + log.Errorf(ctx, "db.Query: %v", err) + } + defer rows.Close() + + for rows.Next() { + var firstName string + var lastName string + if err := rows.Scan(&firstName, &lastName); err != nil { + log.Errorf(ctx, "rows.Scan: %v", err) + continue + } + log.Infof(ctx, "First: %v - Last: %v", firstName, lastName) + } + if err := rows.Err(); err != nil { + log.Errorf(ctx, "Row error: %v", err) + } +*/ +package cloudsql + +import ( + "net" +) + +// Dial connects to the named Cloud SQL instance. +func Dial(instance string) (net.Conn, error) { + return connect(instance) +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go new file mode 100644 index 0000000..af62dba --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go @@ -0,0 +1,17 @@ +// Copyright 2013 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// +build appengine + +package cloudsql + +import ( + "net" + + "appengine/cloudsql" +) + +func connect(instance string) (net.Conn, error) { + return cloudsql.Dial(instance) +} diff --git a/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go new file mode 100644 index 0000000..90fa7b3 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go @@ -0,0 +1,16 @@ +// Copyright 2013 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// +build !appengine + +package cloudsql + +import ( + "errors" + "net" +) + +func connect(instance string) (net.Conn, error) { + return nil, errors.New(`cloudsql: not supported in App Engine "flexible environment"`) +}