From 3986696813f6007cac8e5d1449424014560ad6df Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 19 Oct 2016 11:57:37 +0200 Subject: [PATCH 001/104] testing enum as part of PK reference: https://github.com/github/gh-ost/issues/273 enum values are intentionally not alphabetically monotonic --- localtests/enum-pk/create.sql | 29 +++++++++++++++++++++++++++++ localtests/enum-pk/extra_args | 1 + 2 files changed, 30 insertions(+) create mode 100644 localtests/enum-pk/create.sql create mode 100644 localtests/enum-pk/extra_args diff --git a/localtests/enum-pk/create.sql b/localtests/enum-pk/create.sql new file mode 100644 index 0000000..4ba7743 --- /dev/null +++ b/localtests/enum-pk/create.sql @@ -0,0 +1,29 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + e enum('red', 'green', 'blue', 'orange') null default null collate 'utf8_bin', + primary key(id, e) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11, 'red'); + set @last_insert_id := last_insert_id(); + insert into gh_ost_test values (@last_insert_id, 11, 'green'); + insert into gh_ost_test values (null, 13, 'green'); + insert into gh_ost_test values (null, 17, 'blue'); + set @last_insert_id := last_insert_id(); + update gh_ost_test set e='orange' where id = @last_insert_id; + insert into gh_ost_test values (null, 23, null); + set @last_insert_id := last_insert_id(); + update gh_ost_test set i=i+1, e=null where id = @last_insert_id; +end ;; diff --git a/localtests/enum-pk/extra_args b/localtests/enum-pk/extra_args new file mode 100644 index 0000000..f369a56 --- /dev/null +++ b/localtests/enum-pk/extra_args @@ -0,0 +1 @@ +--alter="change e e enum('red', 'green', 'blue', 'orange', 'yellow') null default null collate 'utf8_bin'" From 25166e33c7236a6d1d7a60fceec2e2b78f85e8f1 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 19 Oct 2016 15:22:29 +0200 Subject: [PATCH 002/104] solving the enum-as-part-of-pk bug --- build.sh | 2 +- go/logic/applier.go | 8 ++-- go/logic/inspect.go | 7 +++- go/sql/builder.go | 69 ++++++++++++++++++++--------------- go/sql/builder_test.go | 14 +++---- go/sql/types.go | 1 + localtests/enum-pk/extra_args | 1 - 7 files changed, 58 insertions(+), 44 deletions(-) delete mode 100644 localtests/enum-pk/extra_args diff --git a/build.sh b/build.sh index 947a68d..b930269 100755 --- a/build.sh +++ b/build.sh @@ -2,7 +2,7 @@ # # -RELEASE_VERSION="1.0.23" +RELEASE_VERSION="1.0.24" function build { osname=$1 diff --git a/go/logic/applier.go b/go/logic/applier.go index eb23af8..07aa436 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -315,7 +315,7 @@ func (this *Applier) ExecuteThrottleQuery() (int64, error) { // ReadMigrationMinValues returns the minimum values to be iterated on rowcopy func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names()) + query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &uniqueKey.Columns) if err != nil { return err } @@ -336,7 +336,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { // ReadMigrationMaxValues returns the maximum values to be iterated on rowcopy func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names()) + query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &uniqueKey.Columns) if err != nil { return err } @@ -377,7 +377,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, - this.migrationContext.UniqueKey.Columns.Names(), + &this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(), atomic.LoadInt64(&this.migrationContext.ChunkSize), @@ -419,7 +419,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.SharedColumns.Names(), this.migrationContext.MappedSharedColumns.Names(), this.migrationContext.UniqueKey.Name, - this.migrationContext.UniqueKey.Columns.Names(), + &this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), this.migrationContext.GetIteration() == 0, diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 1b6f7f0..441afec 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -135,7 +135,7 @@ func (this *Inspector) InspectOriginalAndGhostTables() (err error) { // This additional step looks at which columns are unsigned. We could have merged this within // the `getTableColumns()` function, but it's a later patch and introduces some complexity; I feel // comfortable in doing this as a separate step. - this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns) + this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns) this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.GhostTableColumns, this.migrationContext.MappedSharedColumns) for i := range this.migrationContext.SharedColumns.Columns() { @@ -532,6 +532,11 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL columnsList.GetColumn(columnName).Type = sql.DateTimeColumnType } } + if strings.HasPrefix(columnType, "enum") { + for _, columnsList := range columnsLists { + columnsList.GetColumn(columnName).Type = sql.EnumColumnValue + } + } if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { for _, columnsList := range columnsLists { columnsList.SetCharset(columnName, charset) diff --git a/go/sql/builder.go b/go/sql/builder.go index ba587c2..c455992 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -170,12 +170,12 @@ func BuildRangeComparison(columns []string, values []string, args []interface{}, return result, explodedArgs, nil } -func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { - values := buildPreparedValues(len(columns)) - return BuildRangeComparison(columns, values, args, comparisonSign) +func BuildRangePreparedComparison(columns *ColumnList, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { + values := buildColumnsPreparedValues(columns) + return BuildRangeComparison(columns.Names(), values, args, comparisonSign) } -func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { +func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { if len(sharedColumns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery") } @@ -200,12 +200,12 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin if includeRangeStartValues { minRangeComparisonSign = GreaterThanOrEqualsComparisonSign } - rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, rangeStartArgs, minRangeComparisonSign) + rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeStartValues, rangeStartArgs, minRangeComparisonSign) if err != nil { return "", explodedArgs, err } explodedArgs = append(explodedArgs, rangeExplodedArgs...) - rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, rangeEndArgs, LessThanOrEqualsComparisonSign) + rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeEndValues, rangeEndArgs, LessThanOrEqualsComparisonSign) if err != nil { return "", explodedArgs, err } @@ -225,14 +225,14 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin return result, explodedArgs, nil } -func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { - rangeStartValues := buildPreparedValues(len(uniqueKeyColumns)) - rangeEndValues := buildPreparedValues(len(uniqueKeyColumns)) +func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { + rangeStartValues := buildColumnsPreparedValues(uniqueKeyColumns) + rangeEndValues := buildColumnsPreparedValues(uniqueKeyColumns) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) } -func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) { - if len(uniqueKeyColumns) == 0 { +func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) { + if uniqueKeyColumns.Len() == 0 { return "", explodedArgs, fmt.Errorf("Got 0 columns in BuildUniqueKeyRangeEndPreparedQuery") } databaseName = EscapeName(databaseName) @@ -253,13 +253,18 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK } explodedArgs = append(explodedArgs, rangeExplodedArgs...) - uniqueKeyColumns = duplicateNames(uniqueKeyColumns) - uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - for i := range uniqueKeyColumns { - uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) - uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumns[i]) - uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumns[i]) + uniqueKeyColumnNames := duplicateNames(uniqueKeyColumns.Names()) + uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) + uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) + for i, column := range uniqueKeyColumns.Columns() { + uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i]) + if column.Type == EnumColumnValue { + uniqueKeyColumnAscending[i] = fmt.Sprintf("concat(%s) asc", uniqueKeyColumnNames[i]) + uniqueKeyColumnDescending[i] = fmt.Sprintf("concat(%s) desc", uniqueKeyColumnNames[i]) + } else { + uniqueKeyColumnAscending[i] = fmt.Sprintf("%s asc", uniqueKeyColumnNames[i]) + uniqueKeyColumnDescending[i] = fmt.Sprintf("%s desc", uniqueKeyColumnNames[i]) + } } result = fmt.Sprintf(` select /* gh-ost %s.%s %s */ %s @@ -276,8 +281,8 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK order by %s limit 1 - `, databaseName, tableName, hint, strings.Join(uniqueKeyColumns, ", "), - strings.Join(uniqueKeyColumns, ", "), databaseName, tableName, + `, databaseName, tableName, hint, strings.Join(uniqueKeyColumnNames, ", "), + strings.Join(uniqueKeyColumnNames, ", "), databaseName, tableName, rangeStartComparison, rangeEndComparison, strings.Join(uniqueKeyColumnAscending, ", "), chunkSize, strings.Join(uniqueKeyColumnDescending, ", "), @@ -285,26 +290,30 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK return result, explodedArgs, nil } -func BuildUniqueKeyMinValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) { +func BuildUniqueKeyMinValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList) (string, error) { return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "asc") } -func BuildUniqueKeyMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string) (string, error) { +func BuildUniqueKeyMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList) (string, error) { return buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName, uniqueKeyColumns, "desc") } -func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns []string, order string) (string, error) { - if len(uniqueKeyColumns) == 0 { +func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uniqueKeyColumns *ColumnList, order string) (string, error) { + if uniqueKeyColumns.Len() == 0 { return "", fmt.Errorf("Got 0 columns in BuildUniqueKeyMinMaxValuesPreparedQuery") } databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) - uniqueKeyColumns = duplicateNames(uniqueKeyColumns) - uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - for i := range uniqueKeyColumns { - uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) - uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumns[i], order) + uniqueKeyColumnNames := duplicateNames(uniqueKeyColumns.Names()) + uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) + for i, column := range uniqueKeyColumns.Columns() { + uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i]) + if column.Type == EnumColumnValue { + uniqueKeyColumnOrder[i] = fmt.Sprintf("concat(%s) %s", uniqueKeyColumnNames[i], order) + } else { + uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumnNames[i], order) + } } query := fmt.Sprintf(` select /* gh-ost %s.%s */ %s @@ -313,7 +322,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni order by %s limit 1 - `, databaseName, tableName, strings.Join(uniqueKeyColumns, ", "), + `, databaseName, tableName, strings.Join(uniqueKeyColumnNames, ", "), databaseName, tableName, strings.Join(uniqueKeyColumnOrder, ", "), ) diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index 806b77b..46c44e1 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -166,7 +166,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { sharedColumns := []string{"id", "name", "position"} { uniqueKey := "PRIMARY" - uniqueKeyColumns := []string{"id"} + uniqueKeyColumns := NewColumnList([]string{"id"}) rangeStartValues := []string{"@v1s"} rangeEndValues := []string{"@v1e"} rangeStartArgs := []interface{}{3} @@ -185,7 +185,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { } { uniqueKey := "name_position_uidx" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartValues := []string{"@v1s", "@v2s"} rangeEndValues := []string{"@v1e", "@v2e"} rangeStartArgs := []interface{}{3, 17} @@ -212,7 +212,7 @@ func TestBuildRangeInsertQueryRenameMap(t *testing.T) { mappedSharedColumns := []string{"id", "name", "location"} { uniqueKey := "PRIMARY" - uniqueKeyColumns := []string{"id"} + uniqueKeyColumns := NewColumnList([]string{"id"}) rangeStartValues := []string{"@v1s"} rangeEndValues := []string{"@v1e"} rangeStartArgs := []interface{}{3} @@ -231,7 +231,7 @@ func TestBuildRangeInsertQueryRenameMap(t *testing.T) { } { uniqueKey := "name_position_uidx" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartValues := []string{"@v1s", "@v2s"} rangeEndValues := []string{"@v1e", "@v2e"} rangeStartArgs := []interface{}{3, 17} @@ -257,7 +257,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { sharedColumns := []string{"id", "name", "position"} { uniqueKey := "name_position_uidx" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} @@ -279,7 +279,7 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { originalTableName := "tbl" var chunkSize int64 = 500 { - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} @@ -309,7 +309,7 @@ func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) { databaseName := "mydb" originalTableName := "tbl" - uniqueKeyColumns := []string{"name", "position"} + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) { query, err := BuildUniqueKeyMinValuesPreparedQuery(databaseName, originalTableName, uniqueKeyColumns) test.S(t).ExpectNil(err) diff --git a/go/sql/types.go b/go/sql/types.go index 1c57fbb..720f92f 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -18,6 +18,7 @@ const ( UnknownColumnType ColumnType = iota TimestampColumnType = iota DateTimeColumnType = iota + EnumColumnValue = iota ) type TimezoneConvertion struct { diff --git a/localtests/enum-pk/extra_args b/localtests/enum-pk/extra_args deleted file mode 100644 index f369a56..0000000 --- a/localtests/enum-pk/extra_args +++ /dev/null @@ -1 +0,0 @@ ---alter="change e e enum('red', 'green', 'blue', 'orange', 'yellow') null default null collate 'utf8_bin'" From adad4b6c34f3078a4acebd7b3b4ed3a98caafbe7 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 20 Oct 2016 10:07:46 +0200 Subject: [PATCH 003/104] added enum limitation documentation --- doc/requirements-and-limitations.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/requirements-and-limitations.md b/doc/requirements-and-limitations.md index ae577f2..9129d42 100644 --- a/doc/requirements-and-limitations.md +++ b/doc/requirements-and-limitations.md @@ -32,3 +32,4 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - We began working towards removing this limitation. See tracking issue: https://github.com/github/gh-ost/issues/163 - Multisource is not supported when migrating via replica. It _should_ work (but never tested) when connecting directly to master (`--allow-on-master`) - Master-master setup is only supported in active-passive setup. Active-active (where table is being written to on both masters concurrently) is unsupported. It may be supported in the future. +- If you have en `enum` field as part of your migration key (typically the `PRIMARY KEY`), migration performance will be degraded and potentially bad. [Read more](https://github.com/github/gh-ost/pull/277#issuecomment-254811520) From bf92eec2145d0448bca3c7e1024f76e2fdc25847 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 20 Oct 2016 11:29:30 +0200 Subject: [PATCH 004/104] validating table structure on applier and migrator - reading column list on applier - comparing original table on applier and migrator, expecting exact column list - or else bailing out --- build.sh | 2 +- go/base/context.go | 1 + go/logic/applier.go | 13 +++++++++++++ go/logic/inspect.go | 38 ++++++++++---------------------------- go/logic/migrator.go | 21 +++++++++++---------- go/mysql/utils.go | 27 +++++++++++++++++++++++++++ 6 files changed, 63 insertions(+), 39 deletions(-) diff --git a/build.sh b/build.sh index 947a68d..c12b6a1 100755 --- a/build.sh +++ b/build.sh @@ -2,7 +2,7 @@ # # -RELEASE_VERSION="1.0.23" +RELEASE_VERSION="1.0.26" function build { osname=$1 diff --git a/go/base/context.go b/go/base/context.go index d291717..fe201e4 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -147,6 +147,7 @@ type MigrationContext struct { UserCommandedUnpostponeFlag int64 PanicAbort chan error + OriginalTableColumnsOnApplier *sql.ColumnList OriginalTableColumns *sql.ColumnList OriginalTableUniqueKeys [](*sql.UniqueKey) GhostTableColumns *sql.ColumnList diff --git a/go/logic/applier.go b/go/logic/applier.go index eb23af8..99f7098 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -67,6 +67,9 @@ func (this *Applier) InitDBConnections() (err error) { } else { this.connectionConfig.ImpliedKey = impliedKey } + if err := this.readTableColumns(); err != nil { + return err + } return nil } @@ -95,6 +98,16 @@ func (this *Applier) validateAndReadTimeZone() error { return nil } +// readTableColumns reads table columns on applier +func (this *Applier) readTableColumns() (err error) { + log.Infof("Examining table structure on applier") + this.migrationContext.OriginalTableColumnsOnApplier, err = mysql.GetTableColumns(this.db, this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName) + if err != nil { + return err + } + return nil +} + // showTableStatus returns the output of `show table status like '...'` command func (this *Applier) showTableStatus(tableName string) (rowMap sqlutils.RowMap) { rowMap = nil diff --git a/go/logic/inspect.go b/go/logic/inspect.go index ceb5c16..79a9c91 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -8,6 +8,7 @@ package logic import ( gosql "database/sql" "fmt" + "reflect" "strings" "sync/atomic" @@ -83,7 +84,7 @@ func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (colum if len(uniqueKeys) == 0 { return columns, uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") } - columns, err = this.getTableColumns(this.migrationContext.DatabaseName, tableName) + columns, err = mysql.GetTableColumns(this.db, this.migrationContext.DatabaseName, tableName) if err != nil { return columns, uniqueKeys, err } @@ -99,9 +100,15 @@ func (this *Inspector) InspectOriginalTable() (err error) { return nil } -// InspectOriginalAndGhostTables compares original and ghost tables to see whether the migration +// inspectOriginalAndGhostTables compares original and ghost tables to see whether the migration // makes sense and is valid. It extracts the list of shared columns and the chosen migration unique key -func (this *Inspector) InspectOriginalAndGhostTables() (err error) { +func (this *Inspector) inspectOriginalAndGhostTables() (err error) { + originalNamesOnApplier := this.migrationContext.OriginalTableColumnsOnApplier.Names() + originalNames := this.migrationContext.OriginalTableColumns.Names() + if !reflect.DeepEqual(originalNames, originalNamesOnApplier) { + return fmt.Errorf("It seems like table structure is not identical between master and replica. This scenario is not supported.") + } + this.migrationContext.GhostTableColumns, this.migrationContext.GhostTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.GetGhostTableName()) if err != nil { return err @@ -478,31 +485,6 @@ func (this *Inspector) CountTableRows() error { return nil } -// getTableColumns reads column list from given table -func (this *Inspector) getTableColumns(databaseName, tableName string) (*sql.ColumnList, error) { - query := fmt.Sprintf(` - show columns from %s.%s - `, - sql.EscapeName(databaseName), - sql.EscapeName(tableName), - ) - columnNames := []string{} - err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { - columnNames = append(columnNames, rowMap.GetString("Field")) - return nil - }) - if err != nil { - return nil, err - } - if len(columnNames) == 0 { - return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out", - sql.EscapeName(databaseName), - sql.EscapeName(tableName), - ) - } - return sql.NewColumnList(columnNames), nil -} - // applyColumnTypes func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsLists ...*sql.ColumnList) error { query := ` diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 7c96871..b0a6246 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -26,7 +26,7 @@ import ( type ChangelogState string const ( - TablesInPlace ChangelogState = "TablesInPlace" + GhostTableMigrated ChangelogState = "GhostTableMigrated" AllEventsUpToLockProcessed = "AllEventsUpToLockProcessed" ) @@ -58,7 +58,7 @@ type Migrator struct { migrationContext *base.MigrationContext firstThrottlingCollected chan bool - tablesInPlace chan bool + ghostTableMigrated chan bool rowCopyComplete chan bool allEventsUpToLockProcessed chan bool @@ -76,7 +76,7 @@ func NewMigrator() *Migrator { migrator := &Migrator{ migrationContext: base.GetMigrationContext(), parser: sql.NewParser(), - tablesInPlace: make(chan bool), + ghostTableMigrated: make(chan bool), firstThrottlingCollected: make(chan bool, 1), rowCopyComplete: make(chan bool), allEventsUpToLockProcessed: make(chan bool), @@ -182,9 +182,9 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er } changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3)) switch changelogState { - case TablesInPlace: + case GhostTableMigrated: { - this.tablesInPlace <- true + this.ghostTableMigrated <- true } case AllEventsUpToLockProcessed: { @@ -291,14 +291,14 @@ func (this *Migrator) Migrate() (err error) { return err } - log.Infof("Waiting for tables to be in place") - <-this.tablesInPlace - log.Debugf("Tables are in place") + log.Infof("Waiting for ghost table to be migrated") + <-this.ghostTableMigrated + log.Debugf("ghost table migrated") // Yay! We now know the Ghost and Changelog tables are good to examine! // When running on replica, this means the replica has those tables. When running // on master this is always true, of course, and yet it also implies this knowledge // is in the binlogs. - if err := this.inspector.InspectOriginalAndGhostTables(); err != nil { + if err := this.inspector.inspectOriginalAndGhostTables(); err != nil { return err } // Validation complete! We're good to execute this migration @@ -926,12 +926,13 @@ func (this *Migrator) initiateApplier() error { log.Errorf("Unable to create ghost table, see further error details. Perhaps a previous migration failed without dropping the table? Bailing out") return err } + if err := this.applier.AlterGhost(); err != nil { log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out") return err } - this.applier.WriteChangelogState(string(TablesInPlace)) + this.applier.WriteChangelogState(string(GhostTableMigrated)) go this.applier.InitiateHeartbeat() return nil } diff --git a/go/mysql/utils.go b/go/mysql/utils.go index cf1ce54..ae22e3a 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -10,6 +10,8 @@ import ( "fmt" "time" + "github.com/github/gh-ost/go/sql" + "github.com/outbrain/golib/log" "github.com/outbrain/golib/sqlutils" ) @@ -149,3 +151,28 @@ func GetInstanceKey(db *gosql.DB) (instanceKey *InstanceKey, err error) { err = db.QueryRow(`select @@global.hostname, @@global.port`).Scan(&instanceKey.Hostname, &instanceKey.Port) return instanceKey, err } + +// GetTableColumns reads column list from given table +func GetTableColumns(db *gosql.DB, databaseName, tableName string) (*sql.ColumnList, error) { + query := fmt.Sprintf(` + show columns from %s.%s + `, + sql.EscapeName(databaseName), + sql.EscapeName(tableName), + ) + columnNames := []string{} + err := sqlutils.QueryRowsMap(db, query, func(rowMap sqlutils.RowMap) error { + columnNames = append(columnNames, rowMap.GetString("Field")) + return nil + }) + if err != nil { + return nil, err + } + if len(columnNames) == 0 { + return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out", + sql.EscapeName(databaseName), + sql.EscapeName(tableName), + ) + } + return sql.NewColumnList(columnNames), nil +} From b696106cb6994e885aa5d2cf545d8887727a8f04 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 20 Oct 2016 15:05:47 +0200 Subject: [PATCH 005/104] fixing bug introduced for charset and timezone tests --- go/logic/inspect.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index cf830f1..519c5d3 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -135,7 +135,8 @@ func (this *Inspector) InspectOriginalAndGhostTables() (err error) { // This additional step looks at which columns are unsigned. We could have merged this within // the `getTableColumns()` function, but it's a later patch and introduces some complexity; I feel // comfortable in doing this as a separate step. - this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns) + this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns) + this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &this.migrationContext.UniqueKey.Columns) this.applyColumnTypes(this.migrationContext.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.GhostTableColumns, this.migrationContext.MappedSharedColumns) for i := range this.migrationContext.SharedColumns.Columns() { From 7fe7b032e96c43bdcdb8af4c660bb9688c53642f Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 26 Oct 2016 10:40:12 +0200 Subject: [PATCH 006/104] concurrent-rowcount defaults 'true' --- doc/command-line-flags.md | 3 ++- go/cmd/gh-ost/main.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index 404703b..d51522c 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -71,7 +71,8 @@ A `gh-ost` execution need to copy whatever rows you have in your existing table `gh-ost` also supports the `--exact-rowcount` flag. When this flag is given, two things happen: - An initial, authoritative `select count(*) from your_table`. This query may take a long time to complete, but is performed before we begin the massive operations. - When `--concurrent-rowcount` is also specified, this runs in paralell to row copy. + When `--concurrent-rowcount` is also specified, this runs in parallel to row copy. + Note: `--concurrent-rowcount` now defaults to `true`. - A continuous update to the estimate as we make progress applying events. We heuristically update the number of rows based on the queries we process from the binlogs. diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index e61d471..781126d 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -54,7 +54,7 @@ func main() { flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)") flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory)") flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)") - flag.BoolVar(&migrationContext.ConcurrentCountTableRows, "concurrent-rowcount", false, "(with --exact-rowcount), when true: count rows after row-copy begins, concurrently, and adjust row estimate later on; defaults false: first count rows, then start row copy") + flag.BoolVar(&migrationContext.ConcurrentCountTableRows, "concurrent-rowcount", true, "(with --exact-rowcount), when true (default): count rows after row-copy begins, concurrently, and adjust row estimate later on; when false: first count rows, then start row copy") flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica") flag.BoolVar(&migrationContext.AllowedMasterMaster, "allow-master-master", false, "explicitly allow running in a master-master setup") flag.BoolVar(&migrationContext.NullableUniqueKeyAllowed, "allow-nullable-unique-key", false, "allow gh-ost to migrate based on a unique key with nullable columns. As long as no NULL values exist, this should be OK. If NULL values exist in chosen key, data may be corrupted. Use at your own risk!") From 90a38b05c23a10374721f446ee21ebc29ee4c58e Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 26 Oct 2016 17:00:27 +0200 Subject: [PATCH 007/104] working for automated deployment builds --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..63f0df9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/.gopath/ +/bin/ +/libexec/ +/.vendor/ From 02c9b97388b1a925a5f489c455ed5d594e64eba6 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 26 Oct 2016 19:55:37 +0200 Subject: [PATCH 008/104] initial work towards automated builds and CI --- script/bootstrap | 16 +++++++ script/build | 21 +++++++++ script/cibuild | 12 +++++ script/cibuild-gh-ost-build-deploy-tarball | 29 ++++++++++++ script/ensure-go-installed | 51 ++++++++++++++++++++++ script/go | 11 +++++ script/godep | 18 ++++++++ script/shim | 22 ++++++++++ 8 files changed, 180 insertions(+) create mode 100755 script/bootstrap create mode 100755 script/build create mode 100755 script/cibuild create mode 100755 script/cibuild-gh-ost-build-deploy-tarball create mode 100755 script/ensure-go-installed create mode 100755 script/go create mode 100755 script/godep create mode 100755 script/shim diff --git a/script/bootstrap b/script/bootstrap new file mode 100755 index 0000000..6ac885b --- /dev/null +++ b/script/bootstrap @@ -0,0 +1,16 @@ +#!/bin/bash + +set -e + +# Make sure we have the version of Go we want to depend on, either from the +# system or one we grab ourselves. +. script/ensure-go-installed + +# Since we want to be able to build this outside of GOPATH, we set it +# up so it points back to us and go is none the wiser + +set -x +rm -rf .gopath +mkdir -p .gopath/src/github.com/github +ln -s "$PWD" .gopath/src/github.com/github/gh-ost +export GOPATH=$PWD/.gopath:$GOPATH diff --git a/script/build b/script/build new file mode 100755 index 0000000..ef04442 --- /dev/null +++ b/script/build @@ -0,0 +1,21 @@ +#!/bin/bash + +set -e + +. script/bootstrap + +mkdir -p bin +bindir="$PWD"/bin +libexecdir="$PWD"/libexec +scriptdir="$PWD"/script + +# We have a few binaries that we want to build, so let's put them into bin/ + +version=$(git rev-parse HEAD) +describe=$(git describe --tags --always --dirty) + +export GOPATH="$PWD/.gopath" +cd .gopath/src/github.com/github/gh-ost + +# We put the binaries into libexec/ and then put the shim with the appropriate name into bin/ +go build -o "$libexecdir/gh-ost" -ldflags "-X main.AppVersion=${version} -X main.BuildDescribe=${describe}" ./go/cmd/gh-ost/main.go diff --git a/script/cibuild b/script/cibuild new file mode 100755 index 0000000..e15c2ce --- /dev/null +++ b/script/cibuild @@ -0,0 +1,12 @@ +#!/bin/bash + +set -e + +. script/bootstrap + +script/build + +export GITBACKUPS_ENV=test + +cd .gopath/src/github.com/github/gh-ost +go test ./go/... diff --git a/script/cibuild-gh-ost-build-deploy-tarball b/script/cibuild-gh-ost-build-deploy-tarball new file mode 100755 index 0000000..b70a207 --- /dev/null +++ b/script/cibuild-gh-ost-build-deploy-tarball @@ -0,0 +1,29 @@ +#!/bin/sh + +set -e + +script/cibuild + +# Get a fresh directory and make sure to delete it afterwards +build_dir=tmp/build +rm -rf $build_dir +mkdir -p $build_dir +trap "rm -rf $build_dir" EXIT + +commit_sha=$(git rev-parse HEAD) + +if [ $(uname -s) = "Darwin" ]; then + build_arch="$(uname -sr | tr -d ' ' | tr '[:upper:]' '[:lower:]')-$(uname -m)" +else + build_arch="$(lsb_release -sc | tr -d ' ' | tr '[:upper:]' '[:lower:]')-$(uname -m)" +fi + +tarball=$build_dir/${commit_sha}-${build_arch}.tar + +# Create the tarball +tar cvf $tarball --mode="ugo=rx" libexec/ bin/ + +# Compress it and copy it to the directory for the CI to upload it +gzip $tarball +mkdir -p "$BUILD_ARTIFACT_DIR"/gh-ost +cp ${tarball}.gz "$BUILD_ARTIFACT_DIR"/gh-ost/ diff --git a/script/ensure-go-installed b/script/ensure-go-installed new file mode 100755 index 0000000..21c49e6 --- /dev/null +++ b/script/ensure-go-installed @@ -0,0 +1,51 @@ +#!/bin/bash + +GO_VERSION=go1.7 + +GO_PKG_DARWIN=${GO_VERSION}.darwin-amd64.pkg +GO_PKG_DARWIN_SHA=e7089843bc7148ffcc147759985b213604d22bb9fd19bd930b515aa981bf1b22 + +GO_PKG_LINUX=${GO_VERSION}.linux-amd64.tar.gz +GO_PKG_LINUX_SHA=702ad90f705365227e902b42d91dd1a40e48ca7f67a2f4b2fd052aaa4295cd95 + +export ROOTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" +cd $ROOTDIR + +# If Go isn't installed globally, setup environment variables for local install. +if [ -z "$(which go)" ] || [ -z "$(go version | grep $GO_VERSION)" ]; then + GODIR="$ROOTDIR/.vendor/go17" + + if [ $(uname -s) = "Darwin" ]; then + export GOROOT="$GODIR/usr/local/go" + else + export GOROOT="$GODIR/go" + fi + + export PATH="$GOROOT/bin:$PATH" +fi + +# Check if local install exists, and install otherwise. +if [ -z "$(which go)" ] || [ -z "$(go version | grep $GO_VERSION)" ]; then + [ -d "$GODIR" ] && rm -rf $GODIR + mkdir -p "$GODIR" + cd "$GODIR"; + + if [ $(uname -s) = "Darwin" ]; then + curl -L -O https://storage.googleapis.com/golang/$GO_PKG_DARWIN + shasum -a256 $GO_PKG_DARWIN | grep $GO_PKG_DARWIN_SHA + xar -xf $GO_PKG_DARWIN + cpio -i < com.googlecode.go.pkg/Payload + else + curl -L -O https://storage.googleapis.com/golang/$GO_PKG_LINUX + shasum -a256 $GO_PKG_LINUX | grep $GO_PKG_LINUX_SHA + tar xf $GO_PKG_LINUX + fi + + # Prove we did something right + echo "$GO_VERSION installed in $GODIR: Go Binary: $(which go)" +fi + +cd $ROOTDIR + +# Configure the new go to be the first go found +export GOPATH=$ROOTDIR/.vendor diff --git a/script/go b/script/go new file mode 100755 index 0000000..309b591 --- /dev/null +++ b/script/go @@ -0,0 +1,11 @@ +#!/bin/bash + +set -e + +. script/bootstrap + +mkdir -p bin +bindir="$PWD"/bin + +cd .gopath/src/github.com/github/gh-ost +go "$@" diff --git a/script/godep b/script/godep new file mode 100755 index 0000000..b6d37a5 --- /dev/null +++ b/script/godep @@ -0,0 +1,18 @@ +#!/bin/bash + +set -e + +PREVGOPATH=$GOPATH + +. script/bootstrap + +mkdir -p bin +bindir="$PWD"/bin + +# Put the GOPATH for the user before our fake one so we can run `godep save ./...` +if [ -n "$PREVGOPATH" ]; then + export GOPATH=$PREVGOPATH:$GOPATH +fi + +cd .gopath/src/github.com/github/gh-ost +godep "$@" diff --git a/script/shim b/script/shim new file mode 100755 index 0000000..95f3968 --- /dev/null +++ b/script/shim @@ -0,0 +1,22 @@ +#!/bin/bash + +set -e + +# This is a shim to sit in front of the binaries and load the appropriate +# environment variables for the environment set in GITBACKUPS_ENV + +# Default to staging for now +GITBACKUPS_ENV=${GITBACKUPS_ENV:-"staging"} + +# The base directory, which includes bin/ and whatever else once deployed +base=$(cd $(dirname $(dirname ${BASH_SOURCE[0]})) && pwd) + +if [ -f "${base}/.app-config/${GITBACKUPS_ENV}.sh" ]; then + source "${base}/.app-config/${GITBACKUPS_ENV}.sh" +fi + +# Now that we have the env variables to let the binary know where to put the +# data, switch into it +myname=$(basename "$0") + +exec "${base}/libexec/${myname}" "$@" From 0bfff66b51b06134bce2b82c7c09d67cbb89eefc Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 26 Oct 2016 20:03:09 +0200 Subject: [PATCH 009/104] more DML tests for latin1 --- localtests/latin1/create.sql | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/localtests/latin1/create.sql b/localtests/latin1/create.sql index 3ca9b47..bc14aaf 100644 --- a/localtests/latin1/create.sql +++ b/localtests/latin1/create.sql @@ -1,7 +1,7 @@ drop table if exists gh_ost_test; create table gh_ost_test ( id int auto_increment, - t varchar(128), + t varchar(128) charset latin1 collate latin1_swedish_ci, primary key(id) ) auto_increment=1 charset latin1 collate latin1_swedish_ci; @@ -17,5 +17,9 @@ create event gh_ost_test begin insert into gh_ost_test values (null, md5(rand())); insert into gh_ost_test values (null, 'átesting'); + insert into gh_ost_test values (null, 'ádelete'); insert into gh_ost_test values (null, 'testátest'); + update gh_ost_test set t='áupdated' order by id desc limit 1; + update gh_ost_test set t='áupdated1' where t='áupdated' order by id desc limit 1; + delete from gh_ost_test where t='ádelete'; end ;; From b76b1c6f97f667655e0fc5ed7cd7dde5d8d808fe Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 11:58:56 +0200 Subject: [PATCH 010/104] binary to be written directly to bindir --- script/build | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/script/build b/script/build index ef04442..d8e8edd 100755 --- a/script/build +++ b/script/build @@ -17,5 +17,5 @@ describe=$(git describe --tags --always --dirty) export GOPATH="$PWD/.gopath" cd .gopath/src/github.com/github/gh-ost -# We put the binaries into libexec/ and then put the shim with the appropriate name into bin/ -go build -o "$libexecdir/gh-ost" -ldflags "-X main.AppVersion=${version} -X main.BuildDescribe=${describe}" ./go/cmd/gh-ost/main.go +# We put the binaries directly into the bindir, because we have no need for shim wrappers +go build -o "$bindir/gh-ost" -ldflags "-X main.AppVersion=${version} -X main.BuildDescribe=${describe}" ./go/cmd/gh-ost/main.go From ef5a535b244f7acfbb0e305686157aff4088942b Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 12:02:08 +0200 Subject: [PATCH 011/104] removed shim, godep files, which are unused in this setup --- script/godep | 18 ------------------ script/shim | 22 ---------------------- 2 files changed, 40 deletions(-) delete mode 100755 script/godep delete mode 100755 script/shim diff --git a/script/godep b/script/godep deleted file mode 100755 index b6d37a5..0000000 --- a/script/godep +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -set -e - -PREVGOPATH=$GOPATH - -. script/bootstrap - -mkdir -p bin -bindir="$PWD"/bin - -# Put the GOPATH for the user before our fake one so we can run `godep save ./...` -if [ -n "$PREVGOPATH" ]; then - export GOPATH=$PREVGOPATH:$GOPATH -fi - -cd .gopath/src/github.com/github/gh-ost -godep "$@" diff --git a/script/shim b/script/shim deleted file mode 100755 index 95f3968..0000000 --- a/script/shim +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -set -e - -# This is a shim to sit in front of the binaries and load the appropriate -# environment variables for the environment set in GITBACKUPS_ENV - -# Default to staging for now -GITBACKUPS_ENV=${GITBACKUPS_ENV:-"staging"} - -# The base directory, which includes bin/ and whatever else once deployed -base=$(cd $(dirname $(dirname ${BASH_SOURCE[0]})) && pwd) - -if [ -f "${base}/.app-config/${GITBACKUPS_ENV}.sh" ]; then - source "${base}/.app-config/${GITBACKUPS_ENV}.sh" -fi - -# Now that we have the env variables to let the binary know where to put the -# data, switch into it -myname=$(basename "$0") - -exec "${base}/libexec/${myname}" "$@" From a97c25e9dff7ad526f20c1405da52e40c75c330b Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 12:03:24 +0200 Subject: [PATCH 012/104] removed libexec references --- script/build | 1 - script/cibuild-gh-ost-build-deploy-tarball | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/script/build b/script/build index d8e8edd..24be5d1 100755 --- a/script/build +++ b/script/build @@ -6,7 +6,6 @@ set -e mkdir -p bin bindir="$PWD"/bin -libexecdir="$PWD"/libexec scriptdir="$PWD"/script # We have a few binaries that we want to build, so let's put them into bin/ diff --git a/script/cibuild-gh-ost-build-deploy-tarball b/script/cibuild-gh-ost-build-deploy-tarball index b70a207..13560cc 100755 --- a/script/cibuild-gh-ost-build-deploy-tarball +++ b/script/cibuild-gh-ost-build-deploy-tarball @@ -21,7 +21,7 @@ fi tarball=$build_dir/${commit_sha}-${build_arch}.tar # Create the tarball -tar cvf $tarball --mode="ugo=rx" libexec/ bin/ +tar cvf $tarball --mode="ugo=rx" bin/ # Compress it and copy it to the directory for the CI to upload it gzip $tarball From 7b63b4a275aaed7ea0cf9a87b6579b6e22596a70 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 13:52:37 +0200 Subject: [PATCH 013/104] proper cleanup of streamer connection --- go/base/context.go | 1 + go/binlog/gomysql_reader.go | 8 ++++++++ go/logic/migrator.go | 6 +++++- go/logic/streamer.go | 6 ++++++ 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/go/base/context.go b/go/base/context.go index fe201e4..f405fdc 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -145,6 +145,7 @@ type MigrationContext struct { AllEventsUpToLockProcessedInjectedFlag int64 CleanupImminentFlag int64 UserCommandedUnpostponeFlag int64 + CutOverCompleteFlag int64 PanicAbort chan error OriginalTableColumnsOnApplier *sql.ColumnList diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index ac6d890..61f8e3b 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -120,6 +120,9 @@ func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEven // StreamEvents func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesChannel chan<- *BinlogEntry) error { + if canStopStreaming() { + return nil + } for { if canStopStreaming() { break @@ -150,3 +153,8 @@ func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesCha return nil } + +func (this *GoMySQLReader) Close() error { + this.binlogSyncer.Close() + return nil +} diff --git a/go/logic/migrator.go b/go/logic/migrator.go index b0a6246..c135ef8 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -171,7 +171,7 @@ func (this *Migrator) consumeRowCopyComplete() { } func (this *Migrator) canStopStreaming() bool { - return false + return atomic.LoadInt64(&this.migrationContext.CutOverCompleteFlag) != 0 } // onChangelogStateEvent is called when a binlog event operation on the changelog table is intercepted. @@ -345,6 +345,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.cutOver(); err != nil { return err } + atomic.StoreInt64(&this.migrationContext.CutOverCompleteFlag, 1) if err := this.finalCleanup(); err != nil { return nil @@ -1058,6 +1059,9 @@ func (this *Migrator) finalCleanup() error { log.Errore(err) } } + if err := this.eventsStreamer.Close(); err != nil { + log.Errore(err) + } if err := this.retryOperation(this.applier.DropChangelogTable); err != nil { return err diff --git a/go/logic/streamer.go b/go/logic/streamer.go index 86faab1..dc5ba60 100644 --- a/go/logic/streamer.go +++ b/go/logic/streamer.go @@ -217,3 +217,9 @@ func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { } } } + +func (this *EventsStreamer) Close() (err error) { + err = this.binlogReader.Close() + log.Infof("Closed streamer connection. err=%+v", err) + return err +} From 18d12e382d85e90067bfcccfaf702a94ae3a44ec Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 13:53:51 +0200 Subject: [PATCH 014/104] externalized -RELEASE_VERSION --- RELEASE_VERSION | 1 + build.sh | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 RELEASE_VERSION diff --git a/RELEASE_VERSION b/RELEASE_VERSION new file mode 100644 index 0000000..8b54409 --- /dev/null +++ b/RELEASE_VERSION @@ -0,0 +1 @@ +1.0.28 diff --git a/build.sh b/build.sh index c12b6a1..3e1ce6f 100755 --- a/build.sh +++ b/build.sh @@ -2,7 +2,7 @@ # # -RELEASE_VERSION="1.0.26" +RELEASE_VERSION=$(cat RELEASE_VERSION) function build { osname=$1 From 7fa5e405d41a337cb1d9362c0390219607805a99 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 14:51:38 +0200 Subject: [PATCH 015/104] avoid writing heartbeat when throttle commanded by user when throttling on user command there really is no need for injecting heartbeat. The user commanded, therefore gh-ost complies and trusts the reasoning for throttling. What this will allow is complete quiet time. This, in turn, will allow such features as relocating via orchestrator/pseudo-gtid at time of throttling --- go/base/context.go | 19 +++++++++++++++---- go/logic/applier.go | 3 +++ go/logic/migrator.go | 2 +- go/logic/throttler.go | 40 ++++++++++++++++++++-------------------- 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index fe201e4..eff8d8e 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -37,6 +37,13 @@ const ( CutOverTwoStep = iota ) +type ThrottleReasonHint string + +const ( + NoThrottleReasonHint ThrottleReasonHint = "NoThrottleReasonHint" + UserCommandThrottleReasonHint = "UserCommandThrottleReasonHint" +) + var ( envVariableRegexp = regexp.MustCompile("[$][{](.*)[}]") ) @@ -44,12 +51,14 @@ var ( type ThrottleCheckResult struct { ShouldThrottle bool Reason string + ReasonHint ThrottleReasonHint } -func NewThrottleCheckResult(throttle bool, reason string) *ThrottleCheckResult { +func NewThrottleCheckResult(throttle bool, reason string, reasonHint ThrottleReasonHint) *ThrottleCheckResult { return &ThrottleCheckResult{ ShouldThrottle: throttle, Reason: reason, + ReasonHint: reasonHint, } } @@ -138,6 +147,7 @@ type MigrationContext struct { TotalDMLEventsApplied int64 isThrottled bool throttleReason string + throttleReasonHint ThrottleReasonHint throttleGeneralCheckResult ThrottleCheckResult throttleMutex *sync.Mutex IsPostponingCutOver int64 @@ -416,17 +426,18 @@ func (this *MigrationContext) GetThrottleGeneralCheckResult() *ThrottleCheckResu return &result } -func (this *MigrationContext) SetThrottled(throttle bool, reason string) { +func (this *MigrationContext) SetThrottled(throttle bool, reason string, reasonHint ThrottleReasonHint) { this.throttleMutex.Lock() defer this.throttleMutex.Unlock() this.isThrottled = throttle this.throttleReason = reason + this.throttleReasonHint = reasonHint } -func (this *MigrationContext) IsThrottled() (bool, string) { +func (this *MigrationContext) IsThrottled() (bool, string, ThrottleReasonHint) { this.throttleMutex.Lock() defer this.throttleMutex.Unlock() - return this.isThrottled, this.throttleReason + return this.isThrottled, this.throttleReason, this.throttleReasonHint } func (this *MigrationContext) GetReplicationLagQuery() string { diff --git a/go/logic/applier.go b/go/logic/applier.go index 286728a..73d98e8 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -305,6 +305,9 @@ func (this *Applier) InitiateHeartbeat() { // Generally speaking, we would issue a goroutine, but I'd actually rather // have this block the loop rather than spam the master in the event something // goes wrong + if throttle, _, reasonHint := this.migrationContext.IsThrottled(); throttle && (reasonHint == base.UserCommandThrottleReasonHint) { + continue + } if err := injectHeartbeat(); err != nil { return } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index b0a6246..6a1a778 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -803,7 +803,7 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { } else if atomic.LoadInt64(&this.migrationContext.IsPostponingCutOver) > 0 { eta = "due" state = "postponing cut-over" - } else if isThrottled, throttleReason := this.migrationContext.IsThrottled(); isThrottled { + } else if isThrottled, throttleReason, _ := this.migrationContext.IsThrottled(); isThrottled { state = fmt.Sprintf("throttled, %s", throttleReason) } diff --git a/go/logic/throttler.go b/go/logic/throttler.go index ee22931..482087c 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -34,16 +34,16 @@ func NewThrottler(applier *Applier, inspector *Inspector) *Throttler { // shouldThrottle performs checks to see whether we should currently be throttling. // It merely observes the metrics collected by other components, it does not issue // its own metric collection. -func (this *Throttler) shouldThrottle() (result bool, reason string) { +func (this *Throttler) shouldThrottle() (result bool, reason string, reasonHint base.ThrottleReasonHint) { generalCheckResult := this.migrationContext.GetThrottleGeneralCheckResult() if generalCheckResult.ShouldThrottle { - return generalCheckResult.ShouldThrottle, generalCheckResult.Reason + return generalCheckResult.ShouldThrottle, generalCheckResult.Reason, generalCheckResult.ReasonHint } // Replication lag throttle maxLagMillisecondsThrottleThreshold := atomic.LoadInt64(&this.migrationContext.MaxLagMillisecondsThrottleThreshold) lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) if time.Duration(lag) > time.Duration(maxLagMillisecondsThrottleThreshold)*time.Millisecond { - return true, fmt.Sprintf("lag=%fs", time.Duration(lag).Seconds()) + return true, fmt.Sprintf("lag=%fs", time.Duration(lag).Seconds()), base.NoThrottleReasonHint } checkThrottleControlReplicas := true if (this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica) && (atomic.LoadInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag) > 0) { @@ -52,14 +52,14 @@ func (this *Throttler) shouldThrottle() (result bool, reason string) { if checkThrottleControlReplicas { lagResult := this.migrationContext.GetControlReplicasLagResult() if lagResult.Err != nil { - return true, fmt.Sprintf("%+v %+v", lagResult.Key, lagResult.Err) + return true, fmt.Sprintf("%+v %+v", lagResult.Key, lagResult.Err), base.NoThrottleReasonHint } if lagResult.Lag > time.Duration(maxLagMillisecondsThrottleThreshold)*time.Millisecond { - return true, fmt.Sprintf("%+v replica-lag=%fs", lagResult.Key, lagResult.Lag.Seconds()) + return true, fmt.Sprintf("%+v replica-lag=%fs", lagResult.Key, lagResult.Lag.Seconds()), base.NoThrottleReasonHint } } // Got here? No metrics indicates we need throttling. - return false, "" + return false, "", base.NoThrottleReasonHint } // parseChangelogHeartbeat is called when a heartbeat event is intercepted @@ -147,8 +147,8 @@ func (this *Throttler) criticalLoadIsMet() (met bool, variableName string, value // collectGeneralThrottleMetrics reads the once-per-sec metrics, and stores them onto this.migrationContext func (this *Throttler) collectGeneralThrottleMetrics() error { - setThrottle := func(throttle bool, reason string) error { - this.migrationContext.SetThrottleGeneralCheckResult(base.NewThrottleCheckResult(throttle, reason)) + setThrottle := func(throttle bool, reason string, reasonHint base.ThrottleReasonHint) error { + this.migrationContext.SetThrottleGeneralCheckResult(base.NewThrottleCheckResult(throttle, reason, reasonHint)) return nil } @@ -161,7 +161,7 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { criticalLoadMet, variableName, value, threshold, err := this.criticalLoadIsMet() if err != nil { - return setThrottle(true, fmt.Sprintf("%s %s", variableName, err)) + return setThrottle(true, fmt.Sprintf("%s %s", variableName, err), base.NoThrottleReasonHint) } if criticalLoadMet && this.migrationContext.CriticalLoadIntervalMilliseconds == 0 { this.migrationContext.PanicAbort <- fmt.Errorf("critical-load met: %s=%d, >=%d", variableName, value, threshold) @@ -181,18 +181,18 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { // User-based throttle if atomic.LoadInt64(&this.migrationContext.ThrottleCommandedByUser) > 0 { - return setThrottle(true, "commanded by user") + return setThrottle(true, "commanded by user", base.UserCommandThrottleReasonHint) } if this.migrationContext.ThrottleFlagFile != "" { if base.FileExists(this.migrationContext.ThrottleFlagFile) { // Throttle file defined and exists! - return setThrottle(true, "flag-file") + return setThrottle(true, "flag-file", base.NoThrottleReasonHint) } } if this.migrationContext.ThrottleAdditionalFlagFile != "" { if base.FileExists(this.migrationContext.ThrottleAdditionalFlagFile) { // 2nd Throttle file defined and exists! - return setThrottle(true, "flag-file") + return setThrottle(true, "flag-file", base.NoThrottleReasonHint) } } @@ -200,19 +200,19 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { for variableName, threshold := range maxLoad { value, err := this.applier.ShowStatusVariable(variableName) if err != nil { - return setThrottle(true, fmt.Sprintf("%s %s", variableName, err)) + return setThrottle(true, fmt.Sprintf("%s %s", variableName, err), base.NoThrottleReasonHint) } if value >= threshold { - return setThrottle(true, fmt.Sprintf("max-load %s=%d >= %d", variableName, value, threshold)) + return setThrottle(true, fmt.Sprintf("max-load %s=%d >= %d", variableName, value, threshold), base.NoThrottleReasonHint) } } if this.migrationContext.GetThrottleQuery() != "" { if res, _ := this.applier.ExecuteThrottleQuery(); res > 0 { - return setThrottle(true, "throttle-query") + return setThrottle(true, "throttle-query", base.NoThrottleReasonHint) } } - return setThrottle(false, "") + return setThrottle(false, "", base.NoThrottleReasonHint) } // initiateThrottlerMetrics initiates the various processes that collect measurements @@ -237,8 +237,8 @@ func (this *Throttler) initiateThrottlerChecks() error { throttlerTick := time.Tick(100 * time.Millisecond) throttlerFunction := func() { - alreadyThrottling, currentReason := this.migrationContext.IsThrottled() - shouldThrottle, throttleReason := this.shouldThrottle() + alreadyThrottling, currentReason, _ := this.migrationContext.IsThrottled() + shouldThrottle, throttleReason, throttleReasonHint := this.shouldThrottle() if shouldThrottle && !alreadyThrottling { // New throttling this.applier.WriteAndLogChangelog("throttle", throttleReason) @@ -249,7 +249,7 @@ func (this *Throttler) initiateThrottlerChecks() error { // End of throttling this.applier.WriteAndLogChangelog("throttle", "done throttling") } - this.migrationContext.SetThrottled(shouldThrottle, throttleReason) + this.migrationContext.SetThrottled(shouldThrottle, throttleReason, throttleReasonHint) } throttlerFunction() for range throttlerTick { @@ -265,7 +265,7 @@ func (this *Throttler) throttle(onThrottled func()) { for { // IsThrottled() is non-blocking; the throttling decision making takes place asynchronously. // Therefore calling IsThrottled() is cheap - if shouldThrottle, _ := this.migrationContext.IsThrottled(); !shouldThrottle { + if shouldThrottle, _, _ := this.migrationContext.IsThrottled(); !shouldThrottle { return } if onThrottled != nil { From c5cf552dadd1b211d2c395cb1e6b4246fab07039 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 15:05:25 +0200 Subject: [PATCH 016/104] fixed testing dependenies --- .../github.com/outbrain/golib/tests/spec.go | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 vendor/github.com/outbrain/golib/tests/spec.go diff --git a/vendor/github.com/outbrain/golib/tests/spec.go b/vendor/github.com/outbrain/golib/tests/spec.go new file mode 100644 index 0000000..27b15b6 --- /dev/null +++ b/vendor/github.com/outbrain/golib/tests/spec.go @@ -0,0 +1,76 @@ +package tests + +import ( + "testing" +) + +// Spec is an access point to test Expections +type Spec struct { + t *testing.T +} + +// S generates a spec. You will want to use it once in a test file, once in a test or once per each check +func S(t *testing.T) *Spec { + return &Spec{t: t} +} + +// ExpectNil expects given value to be nil, or errors +func (spec *Spec) ExpectNil(actual interface{}) { + if actual == nil { + return + } + spec.t.Errorf("Expected %+v to be nil", actual) +} + +// ExpectNotNil expects given value to be not nil, or errors +func (spec *Spec) ExpectNotNil(actual interface{}) { + if actual != nil { + return + } + spec.t.Errorf("Expected %+v to be not nil", actual) +} + +// ExpectEquals expects given values to be equal (comparison via `==`), or errors +func (spec *Spec) ExpectEquals(actual, value interface{}) { + if actual == value { + return + } + spec.t.Errorf("Expected %+v, got %+v", value, actual) +} + +// ExpectNotEquals expects given values to be nonequal (comparison via `==`), or errors +func (spec *Spec) ExpectNotEquals(actual, value interface{}) { + if !(actual == value) { + return + } + spec.t.Errorf("Expected not %+v", value) +} + +// ExpectEqualsAny expects given actual to equal (comparison via `==`) at least one of given values, or errors +func (spec *Spec) ExpectEqualsAny(actual interface{}, values ...interface{}) { + for _, value := range values { + if actual == value { + return + } + } + spec.t.Errorf("Expected %+v to equal any of given values", actual) +} + +// ExpectNotEqualsAny expects given actual to be nonequal (comparison via `==`)tp any of given values, or errors +func (spec *Spec) ExpectNotEqualsAny(actual interface{}, values ...interface{}) { + for _, value := range values { + if actual == value { + spec.t.Errorf("Expected not %+v", value) + } + } +} + +// ExpectFalse expects given values to be false, or errors +func (spec *Spec) ExpectFalse(actual interface{}) { + spec.ExpectEquals(actual, false) +} + +// ExpectTrue expects given values to be true, or errors +func (spec *Spec) ExpectTrue(actual interface{}) { + spec.ExpectEquals(actual, true) +} From 40b1d9a3199fc51d8d2e251cde814201512f26ef Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 20:37:03 +0200 Subject: [PATCH 017/104] gofmt testing - more verbose on intended tests --- script/cibuild | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/script/cibuild b/script/cibuild index e15c2ce..7e757b5 100755 --- a/script/cibuild +++ b/script/cibuild @@ -4,9 +4,14 @@ set -e . script/bootstrap +echo "Verifying code is formatted via 'gofmt -s -w go/'" +gofmt -s -w go/ +git diff --exit-code --quiet + +echo "Building" script/build -export GITBACKUPS_ENV=test - cd .gopath/src/github.com/github/gh-ost + +echo "Running unit tests" go test ./go/... From 5aad45e3bc8f263fc209f58a8aae104b45fc3f41 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 27 Oct 2016 20:38:45 +0200 Subject: [PATCH 018/104] formatted code via gofmt --- go/binlog/binlog_reader.go | 2 -- go/binlog/gomysql_reader.go | 2 -- 2 files changed, 4 deletions(-) diff --git a/go/binlog/binlog_reader.go b/go/binlog/binlog_reader.go index e7dc2cf..b4fe000 100644 --- a/go/binlog/binlog_reader.go +++ b/go/binlog/binlog_reader.go @@ -5,8 +5,6 @@ package binlog -import () - // BinlogReader is a general interface whose implementations can choose their methods of reading // a binary log file and parsing it into binlog entries type BinlogReader interface { diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index ac6d890..35f93f4 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -17,8 +17,6 @@ import ( "github.com/siddontang/go-mysql/replication" ) -var () - const ( serverId = 99999 ) From bb22431b83bdc21a9492de989887daeca6369842 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 31 Oct 2016 10:25:34 +0100 Subject: [PATCH 019/104] fixed log_slave_updates check logic --- go/logic/inspect.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 595d76c..a6ecd3d 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -335,12 +335,27 @@ func (this *Inspector) validateLogSlaveUpdates() error { if err := this.db.QueryRow(query).Scan(&logSlaveUpdates); err != nil { return err } - if !logSlaveUpdates && !this.migrationContext.InspectorIsAlsoApplier() && !this.migrationContext.IsTungsten { - return fmt.Errorf("%s:%d must have log_slave_updates enabled", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + + if logSlaveUpdates { + log.Infof("log_slave_updates validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + return nil } - log.Infof("binary logs updates validated on %s:%d", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) - return nil + if this.migrationContext.IsTungsten { + log.Warning("log_slave_updates not found on %s:%d, but --tungsten provided, so I'm proceeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + return nil + } + + if this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica { + return fmt.Errorf("%s:%d must have log_slave_updates enabled for testing/migrating on replica", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + } + + if this.migrationContext.InspectorIsAlsoApplier() { + log.Warning("log_slave_updates not found on %s:%d, but executing directly on master, so I'm proceeeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + return nil + } + + return fmt.Errorf("%s:%d must have log_slave_updates enabled for executing migration", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) } // validateTable makes sure the table we need to operate on actually exists From 946427a20b39ec78c1b5e25a105b36b436aada27 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 31 Oct 2016 14:01:31 +0100 Subject: [PATCH 020/104] added ugluy patch for making this work on jessie --- script/cibuild-gh-ost-build-deploy-tarball | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/script/cibuild-gh-ost-build-deploy-tarball b/script/cibuild-gh-ost-build-deploy-tarball index 13560cc..692b42b 100755 --- a/script/cibuild-gh-ost-build-deploy-tarball +++ b/script/cibuild-gh-ost-build-deploy-tarball @@ -27,3 +27,11 @@ tar cvf $tarball --mode="ugo=rx" bin/ gzip $tarball mkdir -p "$BUILD_ARTIFACT_DIR"/gh-ost cp ${tarball}.gz "$BUILD_ARTIFACT_DIR"/gh-ost/ + +### HACK HACK HACK ### +# Blame @carlosmn. In the good way. +# We don't have any jessie machines for building, but a pure-Go binary depends +# on a version of libc and ld which are widely available, so we can copy the +# tarball over with jessie in its name so we can deploy it on jessie machines. +jessie_tarball_name=$(echo $(basename "${tarball}") | sed s/-precise-/-jessie-/) +cp ${tarball}.gz "$BUILD_ARTIFACT_DIR/gh-ost/${jessie_tarball_name}.gz" From 0e7e15804fa1ad19cafe9e35357c81c3e8bc03d1 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 31 Oct 2016 15:09:02 +0100 Subject: [PATCH 021/104] handling subsecond resolution DATETIME in binary logs added localtests for subsecond in DATETIME and TIMESTAMP --- localtests/datetime-submillis/create.sql | 28 +++++++++++++++++++ .../go-mysql/replication/row_event.go | 5 +++- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 localtests/datetime-submillis/create.sql diff --git a/localtests/datetime-submillis/create.sql b/localtests/datetime-submillis/create.sql new file mode 100644 index 0000000..b4e0b0b --- /dev/null +++ b/localtests/datetime-submillis/create.sql @@ -0,0 +1,28 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + dt0 datetime(6), + dt1 datetime(6), + ts2 timestamp(6), + updated tinyint unsigned default 0, + primary key(id), + key i_idx(i) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11, now(), now(), now(), 0); + update gh_ost_test set dt1='2016-10-31 11:22:33.444', updated = 1 where i = 11 order by id desc limit 1; + + insert into gh_ost_test values (null, 13, now(), now(), now(), 0); + update gh_ost_test set ts1='2016-11-01 11:22:33.444', updated = 1 where i = 13 order by id desc limit 1; +end ;; diff --git a/vendor/github.com/siddontang/go-mysql/replication/row_event.go b/vendor/github.com/siddontang/go-mysql/replication/row_event.go index 9506b1e..ff6c913 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/row_event.go +++ b/vendor/github.com/siddontang/go-mysql/replication/row_event.go @@ -642,7 +642,7 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { } //ingore second part, no precision now - //var secPart int64 = tmp % (1 << 24) + var secPart int64 = tmp % (1 << 24) ymdhms := tmp >> 24 ymd := ymdhms >> 17 @@ -657,6 +657,9 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { minute := int((hms >> 6) % (1 << 6)) hour := int((hms >> 12)) + if secPart != 0 { + return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%d", year, month, day, hour, minute, second, secPart), n, nil // commented by Shlomi Noach. Yes I know about `git blame` + } return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second), n, nil // commented by Shlomi Noach. Yes I know about `git blame` } From 3be76cd0ff9fae0349a01ef9625d3bf464acd8cf Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 2 Nov 2016 10:41:40 +0100 Subject: [PATCH 022/104] added limitations clarifications: JSON and FEDERATED --- doc/requirements-and-limitations.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/doc/requirements-and-limitations.md b/doc/requirements-and-limitations.md index d4dc50b..81e7226 100644 --- a/doc/requirements-and-limitations.md +++ b/doc/requirements-and-limitations.md @@ -4,7 +4,7 @@ - You will need to have one server serving Row Based Replication (RBR) format binary logs. Right now `FULL` row image is supported. `MINIMAL` to be supported in the near future. `gh-ost` prefers to work with replicas. You may [still have your master configured with Statement Based Replication](migrating-with-sbr.md) (SBR). -- If you are using a replica, the table must have an identical schema between the master and replica. +- If you are using a replica, the table must have an identical schema between the master and replica. - `gh-ost` requires an account with these privileges: @@ -28,6 +28,8 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - MySQL 5.7 generated columns are not supported. They may be supported in the future. +- MySQL 5.7 `JSON` columns are not supported. They are likely to be supported shortly. + - The two _before_ & _after_ tables must share some `UNIQUE KEY`. Such key would be used by `gh-ost` to iterate the table. - As an example, if your table has a single `UNIQUE KEY` and no `PRIMARY KEY`, and you wish to replace it with a `PRIMARY KEY`, you will need two migrations: one to add the `PRIMARY KEY` (this migration will use the existing `UNIQUE KEY`), another to drop the now redundant `UNIQUE KEY` (this migration will use the `PRIMARY KEY`). @@ -43,4 +45,7 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - Multisource is not supported when migrating via replica. It _should_ work (but never tested) when connecting directly to master (`--allow-on-master`) - Master-master setup is only supported in active-passive setup. Active-active (where table is being written to on both masters concurrently) is unsupported. It may be supported in the future. + - If you have en `enum` field as part of your migration key (typically the `PRIMARY KEY`), migration performance will be degraded and potentially bad. [Read more](https://github.com/github/gh-ost/pull/277#issuecomment-254811520) + +- Migrating a `FEDERATED` table is unsupported and is irrelevant to the problem `gh-ost` tackles. From 88ffb75b8cf05f1181cec02f522a6582624d1f1c Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 2 Nov 2016 12:48:35 +0100 Subject: [PATCH 023/104] reading and reporting replication lag before waiting on initial replication event --- go/logic/inspect.go | 10 ++++++++++ go/logic/migrator.go | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index a6ecd3d..5542cc1 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -11,6 +11,7 @@ import ( "reflect" "strings" "sync/atomic" + "time" "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/mysql" @@ -690,3 +691,12 @@ func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.Connect visitedKeys := mysql.NewInstanceKeyMap() return mysql.GetMasterConnectionConfigSafe(this.connectionConfig, visitedKeys, this.migrationContext.AllowedMasterMaster) } + +func (this *Inspector) getReplicationLag() (replicationLag time.Duration, err error) { + replicationLagQuery := this.migrationContext.GetReplicationLagQuery() + lag, err := mysql.GetReplicationLag( + this.migrationContext.InspectorConnectionConfig, + replicationLagQuery, + ) + return lag, err +} diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 34728c8..88f5423 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -291,7 +291,8 @@ func (this *Migrator) Migrate() (err error) { return err } - log.Infof("Waiting for ghost table to be migrated") + initialLag, _ := this.inspector.getReplicationLag() + log.Infof("Waiting for ghost table to be migrated. Current lag is %+v", initialLag) <-this.ghostTableMigrated log.Debugf("ghost table migrated") // Yay! We now know the Ghost and Changelog tables are good to examine! From c2d4f624afb85c6829385004a10d14fecdc29c9c Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 3 Nov 2016 12:14:53 +0100 Subject: [PATCH 024/104] simplified code --- go/logic/inspect.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 5542cc1..fb77746 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -694,9 +694,9 @@ func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.Connect func (this *Inspector) getReplicationLag() (replicationLag time.Duration, err error) { replicationLagQuery := this.migrationContext.GetReplicationLagQuery() - lag, err := mysql.GetReplicationLag( + replicationLag, err = mysql.GetReplicationLag( this.migrationContext.InspectorConnectionConfig, replicationLagQuery, ) - return lag, err + return replicationLag, err } From 694bf77e862dca9388461fce6895cfcb72eb7d28 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 3 Nov 2016 16:50:16 +0100 Subject: [PATCH 025/104] adding link to Questio labeled issues --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 53d017a..451b5b2 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ Also see: - [requirements and limitations](doc/requirements-and-limitations.md) - [what if?](doc/what-if.md) - [the fine print](doc/the-fine-print.md) +- [Questions](https://github.com/github/gh-ost/issues?q=label%3Aquestion) ## What's in a name? From 668c13cf201ca4a6d9f208ba0707b7fde92849a8 Mon Sep 17 00:00:00 2001 From: Jonathan Camenisch Date: Fri, 4 Nov 2016 11:04:43 -0400 Subject: [PATCH 026/104] Small typo correction --- doc/triggerless-design.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/triggerless-design.md b/doc/triggerless-design.md index e2d89e2..510a301 100644 --- a/doc/triggerless-design.md +++ b/doc/triggerless-design.md @@ -47,7 +47,7 @@ Initial setup is a no-concurrency operation - Applying `alter` on ghost table - Comparing structure of original & ghost table. Looking for shared columns, shared unique keys, validating foreign keys. Choosing shared unique key, the key by which we chunk the table and process it. - Setting up the binlog listener; begin listening on changelog events -- Injecting a "good to go" ebtry onto the changelog table (to be intercepted via binary logs) +- Injecting a "good to go" entry onto the changelog table (to be intercepted via binary logs) - Begin listening on binlog events for original table DMLs - Reading original table's chosen key min/max values From ee447ad56077e2e3421f21db99f2079491e2236e Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 17 Nov 2016 15:20:44 +0100 Subject: [PATCH 027/104] waitForEventsUpToLock timeout more info on AllEventsUpToLockProcessed, before and after injecting/intercepting --- go/logic/migrator.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 88f5423..50ed38f 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -78,8 +78,8 @@ func NewMigrator() *Migrator { parser: sql.NewParser(), ghostTableMigrated: make(chan bool), firstThrottlingCollected: make(chan bool, 1), - rowCopyComplete: make(chan bool), - allEventsUpToLockProcessed: make(chan bool), + rowCopyComplete: make(chan bool, 1), + allEventsUpToLockProcessed: make(chan bool, 1), copyRowsQueue: make(chan tableWriteFunc), applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), @@ -181,6 +181,7 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return nil } changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3)) + log.Infof("Intercepted changelog state %s", changelogState) switch changelogState { case GhostTableMigrated: { @@ -206,7 +207,7 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return fmt.Errorf("Unknown changelog state: %+v", changelogState) } } - log.Debugf("Received state %+v", changelogState) + log.Infof("Handled changelog state %s", changelogState) return nil } @@ -445,6 +446,8 @@ func (this *Migrator) cutOver() (err error) { // Inject the "AllEventsUpToLockProcessed" state hint, wait for it to appear in the binary logs, // make sure the queue is drained. func (this *Migrator) waitForEventsUpToLock() (err error) { + timeout := time.NewTimer(time.Second * time.Duration(this.migrationContext.CutOverLockTimeoutSeconds)) + this.migrationContext.MarkPointOfInterest() waitForEventsUpToLockStartTime := time.Now() @@ -454,7 +457,16 @@ func (this *Migrator) waitForEventsUpToLock() (err error) { } log.Infof("Waiting for events up to lock") atomic.StoreInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 1) - <-this.allEventsUpToLockProcessed + select { + case <-timeout.C: + { + return log.Errorf("Timeout while waiting for events up to lock") + } + case <-this.allEventsUpToLockProcessed: + { + log.Infof("Waiting for events up to lock: got allEventsUpToLockProcessed.") + } + } waitForEventsUpToLockDuration := time.Since(waitForEventsUpToLockStartTime) log.Infof("Done waiting for events up to lock; duration=%+v", waitForEventsUpToLockDuration) From ef874b855163270d4691b050a40a790b8c55819f Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 17 Nov 2016 15:50:54 +0100 Subject: [PATCH 028/104] AllEventsUpToLockProcessed uses unique signature --- go/logic/migrator.go | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 50ed38f..eabc5fa 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -11,6 +11,7 @@ import ( "math" "os" "os/signal" + "strings" "sync/atomic" "syscall" "time" @@ -60,7 +61,7 @@ type Migrator struct { firstThrottlingCollected chan bool ghostTableMigrated chan bool rowCopyComplete chan bool - allEventsUpToLockProcessed chan bool + allEventsUpToLockProcessed chan string rowCopyCompleteFlag int64 inCutOverCriticalActionFlag int64 @@ -78,8 +79,8 @@ func NewMigrator() *Migrator { parser: sql.NewParser(), ghostTableMigrated: make(chan bool), firstThrottlingCollected: make(chan bool, 1), - rowCopyComplete: make(chan bool, 1), - allEventsUpToLockProcessed: make(chan bool, 1), + rowCopyComplete: make(chan bool), + allEventsUpToLockProcessed: make(chan string), copyRowsQueue: make(chan tableWriteFunc), applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), @@ -180,7 +181,8 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" { return nil } - changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3)) + changelogStateString := dmlEvent.NewColumnValues.StringColumn(3) + changelogState := ChangelogState(strings.Split(changelogStateString, ":")[0]) log.Infof("Intercepted changelog state %s", changelogState) switch changelogState { case GhostTableMigrated: @@ -190,7 +192,7 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er case AllEventsUpToLockProcessed: { applyEventFunc := func() error { - this.allEventsUpToLockProcessed <- true + this.allEventsUpToLockProcessed <- changelogStateString return nil } // at this point we know all events up to lock have been read from the streamer, @@ -451,20 +453,28 @@ func (this *Migrator) waitForEventsUpToLock() (err error) { this.migrationContext.MarkPointOfInterest() waitForEventsUpToLockStartTime := time.Now() - log.Infof("Writing changelog state: %+v", AllEventsUpToLockProcessed) - if _, err := this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed)); err != nil { + allEventsUpToLockProcessedChallenge := fmt.Sprintf("%s:%d", string(AllEventsUpToLockProcessed), waitForEventsUpToLockStartTime.UnixNano()) + log.Infof("Writing changelog state: %+v", allEventsUpToLockProcessedChallenge) + if _, err := this.applier.WriteChangelogState(allEventsUpToLockProcessedChallenge); err != nil { return err } log.Infof("Waiting for events up to lock") atomic.StoreInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 1) - select { - case <-timeout.C: - { - return log.Errorf("Timeout while waiting for events up to lock") - } - case <-this.allEventsUpToLockProcessed: - { - log.Infof("Waiting for events up to lock: got allEventsUpToLockProcessed.") + for found := false; !found; { + select { + case <-timeout.C: + { + return log.Errorf("Timeout while waiting for events up to lock") + } + case state := <-this.allEventsUpToLockProcessed: + { + if state == allEventsUpToLockProcessedChallenge { + log.Infof("Waiting for events up to lock: got %s", state) + found = true + } else { + log.Infof("Waiting for events up to lock: skipping %s", state) + } + } } } waitForEventsUpToLockDuration := time.Since(waitForEventsUpToLockStartTime) From 8d987b5aaf52db5428f54a7924b1f4671280aa46 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 17 Nov 2016 15:56:59 +0100 Subject: [PATCH 029/104] extracted parsing of ChangelogState --- go/logic/migrator.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index eabc5fa..20475bf 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -31,6 +31,10 @@ const ( AllEventsUpToLockProcessed = "AllEventsUpToLockProcessed" ) +func ReadChangelogState(s string) ChangelogState { + return ChangelogState(strings.Split(s, ":")[0]) +} + type tableWriteFunc func() error const ( @@ -182,7 +186,7 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return nil } changelogStateString := dmlEvent.NewColumnValues.StringColumn(3) - changelogState := ChangelogState(strings.Split(changelogStateString, ":")[0]) + changelogState := ReadChangelogState(changelogStateString) log.Infof("Intercepted changelog state %s", changelogState) switch changelogState { case GhostTableMigrated: From b00cae11fa408e9e3a58da34b50cdd79db53742d Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 17 Nov 2016 17:10:17 +0100 Subject: [PATCH 030/104] retry cut-over --- go/logic/applier.go | 3 +-- go/logic/migrator.go | 27 ++++++++++++--------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/go/logic/applier.go b/go/logic/applier.go index 73d98e8..a9735b6 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -694,8 +694,7 @@ func (this *Applier) DropAtomicCutOverSentryTableIfExists() error { return this.dropTable(tableName) } -// DropAtomicCutOverSentryTableIfExists checks if the "old" table name -// happens to be a cut-over magic table; if so, it drops it. +// CreateAtomicCutOverSentryTable func (this *Applier) CreateAtomicCutOverSentryTable() error { if err := this.DropAtomicCutOverSentryTableIfExists(); err != nil { return err diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 20475bf..442c730 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -350,7 +350,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.hooksExecutor.onBeforeCutOver(); err != nil { return err } - if err := this.cutOver(); err != nil { + if err := this.retryOperation(this.cutOver); err != nil { return err } atomic.StoreInt64(&this.migrationContext.CutOverCompleteFlag, 1) @@ -384,16 +384,18 @@ func (this *Migrator) cutOver() (err error) { }) this.migrationContext.MarkPointOfInterest() + log.Debugf("checking for cut-over postpone") this.sleepWhileTrue( func() (bool, error) { if this.migrationContext.PostponeCutOverFlagFile == "" { return false, nil } if atomic.LoadInt64(&this.migrationContext.UserCommandedUnpostponeFlag) > 0 { + atomic.StoreInt64(&this.migrationContext.UserCommandedUnpostponeFlag, 0) return false, nil } if base.FileExists(this.migrationContext.PostponeCutOverFlagFile) { - // Throttle file defined and exists! + // Postpone file defined and exists! if atomic.LoadInt64(&this.migrationContext.IsPostponingCutOver) == 0 { if err := this.hooksExecutor.onBeginPostponed(); err != nil { return true, err @@ -431,20 +433,11 @@ func (this *Migrator) cutOver() (err error) { if this.migrationContext.CutOverType == base.CutOverAtomic { // Atomic solution: we use low timeout and multiple attempts. But for // each failed attempt, we throttle until replication lag is back to normal - err := this.retryOperation( - func() error { - return this.executeAndThrottleOnError(this.atomicCutOver) - }, - ) + err := this.atomicCutOver() return err } if this.migrationContext.CutOverType == base.CutOverTwoStep { - err := this.retryOperation( - func() error { - return this.executeAndThrottleOnError(this.cutOverTwoStep) - }, - ) - return err + return this.cutOverTwoStep() } return log.Fatalf("Unknown cut-over type: %d; should never get here!", this.migrationContext.CutOverType) } @@ -452,6 +445,7 @@ func (this *Migrator) cutOver() (err error) { // Inject the "AllEventsUpToLockProcessed" state hint, wait for it to appear in the binary logs, // make sure the queue is drained. func (this *Migrator) waitForEventsUpToLock() (err error) { + // timeout := time.NewTimer(time.Minute * time.Duration(this.migrationContext.CutOverLockTimeoutSeconds)) timeout := time.NewTimer(time.Second * time.Duration(this.migrationContext.CutOverLockTimeoutSeconds)) this.migrationContext.MarkPointOfInterest() @@ -523,7 +517,9 @@ func (this *Migrator) atomicCutOver() (err error) { atomic.StoreInt64(&this.inCutOverCriticalActionFlag, 1) defer atomic.StoreInt64(&this.inCutOverCriticalActionFlag, 0) + okToUnlockTable := make(chan bool, 4) defer func() { + okToUnlockTable <- true this.applier.DropAtomicCutOverSentryTableIfExists() }() @@ -531,7 +527,6 @@ func (this *Migrator) atomicCutOver() (err error) { lockOriginalSessionIdChan := make(chan int64, 2) tableLocked := make(chan error, 2) - okToUnlockTable := make(chan bool, 3) tableUnlocked := make(chan error, 2) go func() { if err := this.applier.AtomicCutOverMagicLock(lockOriginalSessionIdChan, tableLocked, okToUnlockTable, tableUnlocked); err != nil { @@ -545,7 +540,9 @@ func (this *Migrator) atomicCutOver() (err error) { log.Infof("Session locking original & magic tables is %+v", lockOriginalSessionId) // At this point we know the original table is locked. // We know any newly incoming DML on original table is blocked. - this.waitForEventsUpToLock() + if err := this.waitForEventsUpToLock(); err != nil { + return log.Errore(err) + } // Step 2 // We now attempt an atomic RENAME on original & ghost tables, and expect it to block. From 7ab6af8f5fcaac824c26c16b1fbd60847dc9ebda Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 17 Nov 2016 17:22:13 +0100 Subject: [PATCH 031/104] never throttling inside cut-over critical section --- go/base/context.go | 11 +++++++++++ go/logic/migrator.go | 24 ++++++++---------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index f82fb22..21c5758 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -156,6 +156,7 @@ type MigrationContext struct { CleanupImminentFlag int64 UserCommandedUnpostponeFlag int64 CutOverCompleteFlag int64 + InCutOverCriticalSectionFlag int64 PanicAbort chan error OriginalTableColumnsOnApplier *sql.ColumnList @@ -438,6 +439,16 @@ func (this *MigrationContext) SetThrottled(throttle bool, reason string, reasonH func (this *MigrationContext) IsThrottled() (bool, string, ThrottleReasonHint) { this.throttleMutex.Lock() defer this.throttleMutex.Unlock() + + // we don't throttle when cutting over. We _do_ throttle: + // - during copy phase + // - just before cut-over + // - in between cut-over retries + // When cutting over, we need to be aggressive. Cut-over holds table locks. + // We need to release those asap. + if atomic.LoadInt64(&this.InCutOverCriticalSectionFlag) > 0 { + return false, "critical section", NoThrottleReasonHint + } return this.isThrottled, this.throttleReason, this.throttleReasonHint } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 442c730..cf59972 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -67,8 +67,7 @@ type Migrator struct { rowCopyComplete chan bool allEventsUpToLockProcessed chan string - rowCopyCompleteFlag int64 - inCutOverCriticalActionFlag int64 + rowCopyCompleteFlag int64 // copyRowsQueue should not be buffered; if buffered some non-damaging but // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete copyRowsQueue chan tableWriteFunc @@ -409,6 +408,7 @@ func (this *Migrator) cutOver() (err error) { ) atomic.StoreInt64(&this.migrationContext.IsPostponingCutOver, 0) this.migrationContext.MarkPointOfInterest() + log.Debugf("checking for cut-over postpone: complete") if this.migrationContext.TestOnReplica { // With `--test-on-replica` we stop replication thread, and then proceed to use @@ -445,7 +445,6 @@ func (this *Migrator) cutOver() (err error) { // Inject the "AllEventsUpToLockProcessed" state hint, wait for it to appear in the binary logs, // make sure the queue is drained. func (this *Migrator) waitForEventsUpToLock() (err error) { - // timeout := time.NewTimer(time.Minute * time.Duration(this.migrationContext.CutOverLockTimeoutSeconds)) timeout := time.NewTimer(time.Second * time.Duration(this.migrationContext.CutOverLockTimeoutSeconds)) this.migrationContext.MarkPointOfInterest() @@ -488,8 +487,8 @@ func (this *Migrator) waitForEventsUpToLock() (err error) { // There is a point in time where the "original" table does not exist and queries are non-blocked // and failing. func (this *Migrator) cutOverTwoStep() (err error) { - atomic.StoreInt64(&this.inCutOverCriticalActionFlag, 1) - defer atomic.StoreInt64(&this.inCutOverCriticalActionFlag, 0) + atomic.StoreInt64(&this.migrationContext.InCutOverCriticalSectionFlag, 1) + defer atomic.StoreInt64(&this.migrationContext.InCutOverCriticalSectionFlag, 0) atomic.StoreInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 0) if err := this.retryOperation(this.applier.LockOriginalTable); err != nil { @@ -514,8 +513,8 @@ func (this *Migrator) cutOverTwoStep() (err error) { // atomicCutOver func (this *Migrator) atomicCutOver() (err error) { - atomic.StoreInt64(&this.inCutOverCriticalActionFlag, 1) - defer atomic.StoreInt64(&this.inCutOverCriticalActionFlag, 0) + atomic.StoreInt64(&this.migrationContext.InCutOverCriticalSectionFlag, 1) + defer atomic.StoreInt64(&this.migrationContext.InCutOverCriticalSectionFlag, 0) okToUnlockTable := make(chan bool, 4) defer func() { @@ -1022,15 +1021,8 @@ func (this *Migrator) executeWriteFuncs() error { return nil } for { - if atomic.LoadInt64(&this.inCutOverCriticalActionFlag) == 0 { - // we don't throttle when cutting over. We _do_ throttle: - // - during copy phase - // - just before cut-over - // - in between cut-over retries - this.throttler.throttle(nil) - // When cutting over, we need to be aggressive. Cut-over holds table locks. - // We need to release those asap. - } + this.throttler.throttle(nil) + // We give higher priority to event processing, then secondary priority to // rowcopy select { From 7126b281691506818997ef7d06e91feed880c9f6 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 21 Nov 2016 09:18:40 +0100 Subject: [PATCH 032/104] support for --skip-foreign-key-checks --- doc/command-line-flags.md | 6 ++++++ go/base/context.go | 1 + go/cmd/gh-ost/main.go | 1 + go/logic/inspect.go | 4 ++++ 4 files changed, 12 insertions(+) diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index d51522c..1707685 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -63,6 +63,8 @@ Optional. Default is `safe`. See more discussion in [cut-over](cut-over.md) At this time (10-2016) `gh-ost` does not support foreign keys on migrated tables (it bails out when it notices a FK on the migrated table). However, it is able to support _dropping_ of foreign keys via this flag. If you're trying to get rid of foreign keys in your environment, this is a useful flag. +See also: [`skip-foreign-key-checks`](#skip-foreign-key-checks) + ### exact-rowcount A `gh-ost` execution need to copy whatever rows you have in your existing table onto the ghost table. This can, and often be, a large number. Exactly what that number is? @@ -109,6 +111,10 @@ See also: [Sub-second replication lag throttling](subsecond-lag.md) Typically `gh-ost` is used to migrate tables on a master. If you wish to only perform the migration in full on a replica, connect `gh-ost` to said replica and pass `--migrate-on-replica`. `gh-ost` will briefly connect to the master but other issue no changes on the master. Migration will be fully executed on the replica, while making sure to maintain a small replication lag. +### skip-foreign-key-checks + +By default `gh-ost` verifies no foreign keys exist on the migrated table. On servers with large number of tables this check can take a long time. If you're absolutely certain no foreign keys exist (table does not referenece other table nor is referenced by other tables) and wish to save the check time, provide with `--skip-foreign-key-checks`. + ### skip-renamed-columns See `approve-renamed-columns` diff --git a/go/base/context.go b/go/base/context.go index 21c5758..606cb69 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -75,6 +75,7 @@ type MigrationContext struct { AllowedMasterMaster bool SwitchToRowBinlogFormat bool AssumeRBR bool + SkipForeignKeyChecks bool NullableUniqueKeyAllowed bool ApproveRenamedColumns bool SkipRenamedColumns bool diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 781126d..5787b7a 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -62,6 +62,7 @@ func main() { flag.BoolVar(&migrationContext.SkipRenamedColumns, "skip-renamed-columns", false, "in case your `ALTER` statement renames columns, gh-ost will note that and offer its interpretation of the rename. By default gh-ost does not proceed to execute. This flag tells gh-ost to skip the renamed columns, i.e. to treat what gh-ost thinks are renamed columns as unrelated columns. NOTE: you may lose column data") flag.BoolVar(&migrationContext.IsTungsten, "tungsten", false, "explicitly let gh-ost know that you are running on a tungsten-replication based topology (you are likely to also provide --assume-master-host)") flag.BoolVar(&migrationContext.DiscardForeignKeys, "discard-foreign-keys", false, "DANGER! This flag will migrate a table that has foreign keys and will NOT create foreign keys on the ghost table, thus your altered table will have NO foreign keys. This is useful for intentional dropping of foreign keys") + flag.BoolVar(&migrationContext.SkipForeignKeyChecks, "skip-foreign-key-checks", false, "set to 'true' when you know for certain there are no foreign keys on your table, and wish to skip the time it takes for gh-ost to verify that") executeFlag := flag.Bool("execute", false, "actually execute the alter & migrate the table. Default is noop: do some tests and exit") flag.BoolVar(&migrationContext.TestOnReplica, "test-on-replica", false, "Have the migration run on a replica, not on the master. At the end of migration replication is stopped, and tables are swapped and immediately swap-revert. Replication remains stopped and you can compare the two tables for building trust") diff --git a/go/logic/inspect.go b/go/logic/inspect.go index fb77746..4519225 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -388,6 +388,10 @@ func (this *Inspector) validateTable() error { // validateTableForeignKeys makes sure no foreign keys exist on the migrated table func (this *Inspector) validateTableForeignKeys(allowChildForeignKeys bool) error { + if this.migrationContext.SkipForeignKeyChecks { + log.Warning("--skip-foreign-key-checks provided: will not check for foreign keys") + return nil + } query := ` SELECT SUM(REFERENCED_TABLE_NAME IS NOT NULL AND TABLE_SCHEMA=? AND TABLE_NAME=?) as num_child_side_fk, From 5119ea4d31ab5f380d8783f91a5fb2d046a8e622 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 29 Nov 2016 11:08:35 +0100 Subject: [PATCH 033/104] added tests to verify no false positives rename-column found --- go/logic/migrator.go | 2 +- localtests/rename-none-column/create.sql | 8 ++++++++ localtests/rename-none-column/extra_args | 1 + localtests/rename-none-comment/create.sql | 8 ++++++++ localtests/rename-none-comment/extra_args | 1 + 5 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 localtests/rename-none-column/create.sql create mode 100644 localtests/rename-none-column/extra_args create mode 100644 localtests/rename-none-comment/create.sql create mode 100644 localtests/rename-none-comment/extra_args diff --git a/go/logic/migrator.go b/go/logic/migrator.go index cf59972..8f61425 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -229,7 +229,7 @@ func (this *Migrator) validateStatement() (err error) { if this.parser.HasNonTrivialRenames() && !this.migrationContext.SkipRenamedColumns { this.migrationContext.ColumnRenameMap = this.parser.GetNonTrivialRenames() if !this.migrationContext.ApproveRenamedColumns { - return fmt.Errorf("gh-ost believes the ALTER statement renames columns, as follows: %v; as precation, you are asked to confirm gh-ost is correct, and provide with `--approve-renamed-columns`, and we're all happy. Or you can skip renamed columns via `--skip-renamed-columns`, in which case column data may be lost", this.parser.GetNonTrivialRenames()) + return fmt.Errorf("gh-ost believes the ALTER statement renames columns, as follows: %v; as precaution, you are asked to confirm gh-ost is correct, and provide with `--approve-renamed-columns`, and we're all happy. Or you can skip renamed columns via `--skip-renamed-columns`, in which case column data may be lost", this.parser.GetNonTrivialRenames()) } log.Infof("Alter statement has column(s) renamed. gh-ost finds the following renames: %v; --approve-renamed-columns is given and so migration proceeds.", this.parser.GetNonTrivialRenames()) } diff --git a/localtests/rename-none-column/create.sql b/localtests/rename-none-column/create.sql new file mode 100644 index 0000000..1db5081 --- /dev/null +++ b/localtests/rename-none-column/create.sql @@ -0,0 +1,8 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + c1 int not null, + primary key (id) +) auto_increment=1; + +drop event if exists gh_ost_test; diff --git a/localtests/rename-none-column/extra_args b/localtests/rename-none-column/extra_args new file mode 100644 index 0000000..889e91a --- /dev/null +++ b/localtests/rename-none-column/extra_args @@ -0,0 +1 @@ +--alter="add column exchange double comment 'exchange rate used for pay in your own currency'" diff --git a/localtests/rename-none-comment/create.sql b/localtests/rename-none-comment/create.sql new file mode 100644 index 0000000..1db5081 --- /dev/null +++ b/localtests/rename-none-comment/create.sql @@ -0,0 +1,8 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + c1 int not null, + primary key (id) +) auto_increment=1; + +drop event if exists gh_ost_test; diff --git a/localtests/rename-none-comment/extra_args b/localtests/rename-none-comment/extra_args new file mode 100644 index 0000000..627b220 --- /dev/null +++ b/localtests/rename-none-comment/extra_args @@ -0,0 +1 @@ +--alter="add column exchange_rate double comment 'change rate used for pay in your own currency'" From e7cf488818803470f596906b7748d9654c0a267b Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 29 Nov 2016 15:44:38 +0100 Subject: [PATCH 034/104] fixed parsing of quotes and of detecting rename statements --- go/sql/parser.go | 59 +++++++++++++++++++++++++++++++++++-------- go/sql/parser_test.go | 51 +++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 10 deletions(-) diff --git a/go/sql/parser.go b/go/sql/parser.go index 144f265..efdbb74 100644 --- a/go/sql/parser.go +++ b/go/sql/parser.go @@ -8,10 +8,12 @@ package sql import ( "regexp" "strconv" + "strings" ) var ( - renameColumnRegexp = regexp.MustCompile(`(?i)change\s+(column\s+|)([\S]+)\s+([\S]+)\s+`) + stripQuotesRegexp = regexp.MustCompile("('[^']*')") + renameColumnRegexp = regexp.MustCompile(`(?i)\bchange\s+(column\s+|)([\S]+)\s+([\S]+)\s+`) ) type Parser struct { @@ -24,17 +26,54 @@ func NewParser() *Parser { } } -func (this *Parser) ParseAlterStatement(alterStatement string) (err error) { - allStringSubmatch := renameColumnRegexp.FindAllStringSubmatch(alterStatement, -1) - for _, submatch := range allStringSubmatch { - if unquoted, err := strconv.Unquote(submatch[2]); err == nil { - submatch[2] = unquoted - } - if unquoted, err := strconv.Unquote(submatch[3]); err == nil { - submatch[3] = unquoted +func (this *Parser) tokenizeAlterStatement(alterStatement string) (tokens []string, err error) { + terminatingQuote := rune(0) + f := func(c rune) bool { + switch { + case c == terminatingQuote: + terminatingQuote = rune(0) + return false + case terminatingQuote != rune(0): + return false + case c == '\'': + terminatingQuote = c + return false + case c == '(': + terminatingQuote = ')' + return false + default: + return c == ',' } + } - this.columnRenameMap[submatch[2]] = submatch[3] + tokens = strings.FieldsFunc(alterStatement, f) + for i := range tokens { + tokens[i] = strings.TrimSpace(tokens[i]) + } + return tokens, nil +} + +func (this *Parser) stripQuotesFromAlterStatement(alterStatement string) (strippedStatement string) { + strippedStatement = alterStatement + strippedStatement = stripQuotesRegexp.ReplaceAllString(strippedStatement, "''") + return strippedStatement +} + +func (this *Parser) ParseAlterStatement(alterStatement string) (err error) { + alterTokens, _ := this.tokenizeAlterStatement(alterStatement) + for _, alterToken := range alterTokens { + alterToken = this.stripQuotesFromAlterStatement(alterToken) + allStringSubmatch := renameColumnRegexp.FindAllStringSubmatch(alterToken, -1) + for _, submatch := range allStringSubmatch { + if unquoted, err := strconv.Unquote(submatch[2]); err == nil { + submatch[2] = unquoted + } + if unquoted, err := strconv.Unquote(submatch[3]); err == nil { + submatch[3] = unquoted + } + + this.columnRenameMap[submatch[2]] = submatch[3] + } } return nil } diff --git a/go/sql/parser_test.go b/go/sql/parser_test.go index 2107963..a4d4075 100644 --- a/go/sql/parser_test.go +++ b/go/sql/parser_test.go @@ -6,6 +6,7 @@ package sql import ( + "reflect" "testing" "github.com/outbrain/golib/log" @@ -66,3 +67,53 @@ func TestParseAlterStatementNonTrivial(t *testing.T) { test.S(t).ExpectEquals(renames["f"], "fl") } } + +func TestTokenizeAlterStatement(t *testing.T) { + parser := NewParser() + { + alterStatement := "add column t int" + tokens, _ := parser.tokenizeAlterStatement(alterStatement) + test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int"})) + } + { + alterStatement := "add column t int, change column i int" + tokens, _ := parser.tokenizeAlterStatement(alterStatement) + test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "change column i int"})) + } + { + alterStatement := "add column t int, change column i int 'some comment'" + tokens, _ := parser.tokenizeAlterStatement(alterStatement) + test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "change column i int 'some comment'"})) + } + { + alterStatement := "add column t int, change column i int 'some comment, with comma'" + tokens, _ := parser.tokenizeAlterStatement(alterStatement) + test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "change column i int 'some comment, with comma'"})) + } + { + alterStatement := "add column t int, add column e enum('a','b','c')" + tokens, _ := parser.tokenizeAlterStatement(alterStatement) + log.Errorf("%#v", tokens) + test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "add column e enum('a','b','c')"})) + } + { + alterStatement := "add column t int(11), add column e enum('a','b','c')" + tokens, _ := parser.tokenizeAlterStatement(alterStatement) + log.Errorf("%#v", tokens) + test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int(11)", "add column e enum('a','b','c')"})) + } +} + +func TestStripQuotesFromAlterStatement(t *testing.T) { + parser := NewParser() + { + alterStatement := "add column e enum('a','b','c')" + strippedStatement := parser.stripQuotesFromAlterStatement(alterStatement) + test.S(t).ExpectEquals(strippedStatement, "add column e enum('','','')") + } + { + alterStatement := "change column i int 'some comment, with comma'" + strippedStatement := parser.stripQuotesFromAlterStatement(alterStatement) + test.S(t).ExpectEquals(strippedStatement, "change column i int ''") + } +} From 0a707688e0c6095938bed4902cd314588ce690f8 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 29 Nov 2016 15:47:39 +0100 Subject: [PATCH 035/104] added decimal test --- go/sql/parser.go | 10 +++++----- go/sql/parser_test.go | 11 ++++++++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/go/sql/parser.go b/go/sql/parser.go index efdbb74..b81b31c 100644 --- a/go/sql/parser.go +++ b/go/sql/parser.go @@ -12,8 +12,8 @@ import ( ) var ( - stripQuotesRegexp = regexp.MustCompile("('[^']*')") - renameColumnRegexp = regexp.MustCompile(`(?i)\bchange\s+(column\s+|)([\S]+)\s+([\S]+)\s+`) + sanitizeQuotesRegexp = regexp.MustCompile("('[^']*')") + renameColumnRegexp = regexp.MustCompile(`(?i)\bchange\s+(column\s+|)([\S]+)\s+([\S]+)\s+`) ) type Parser struct { @@ -53,16 +53,16 @@ func (this *Parser) tokenizeAlterStatement(alterStatement string) (tokens []stri return tokens, nil } -func (this *Parser) stripQuotesFromAlterStatement(alterStatement string) (strippedStatement string) { +func (this *Parser) sanitizeQuotesFromAlterStatement(alterStatement string) (strippedStatement string) { strippedStatement = alterStatement - strippedStatement = stripQuotesRegexp.ReplaceAllString(strippedStatement, "''") + strippedStatement = sanitizeQuotesRegexp.ReplaceAllString(strippedStatement, "''") return strippedStatement } func (this *Parser) ParseAlterStatement(alterStatement string) (err error) { alterTokens, _ := this.tokenizeAlterStatement(alterStatement) for _, alterToken := range alterTokens { - alterToken = this.stripQuotesFromAlterStatement(alterToken) + alterToken = this.sanitizeQuotesFromAlterStatement(alterToken) allStringSubmatch := renameColumnRegexp.FindAllStringSubmatch(alterToken, -1) for _, submatch := range allStringSubmatch { if unquoted, err := strconv.Unquote(submatch[2]); err == nil { diff --git a/go/sql/parser_test.go b/go/sql/parser_test.go index a4d4075..876429a 100644 --- a/go/sql/parser_test.go +++ b/go/sql/parser_test.go @@ -90,6 +90,11 @@ func TestTokenizeAlterStatement(t *testing.T) { tokens, _ := parser.tokenizeAlterStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "change column i int 'some comment, with comma'"})) } + { + alterStatement := "add column t int, add column d decimal(10,2)" + tokens, _ := parser.tokenizeAlterStatement(alterStatement) + test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "add column d decimal(10,2)"})) + } { alterStatement := "add column t int, add column e enum('a','b','c')" tokens, _ := parser.tokenizeAlterStatement(alterStatement) @@ -104,16 +109,16 @@ func TestTokenizeAlterStatement(t *testing.T) { } } -func TestStripQuotesFromAlterStatement(t *testing.T) { +func TestSanitizeQuotesFromAlterStatement(t *testing.T) { parser := NewParser() { alterStatement := "add column e enum('a','b','c')" - strippedStatement := parser.stripQuotesFromAlterStatement(alterStatement) + strippedStatement := parser.sanitizeQuotesFromAlterStatement(alterStatement) test.S(t).ExpectEquals(strippedStatement, "add column e enum('','','')") } { alterStatement := "change column i int 'some comment, with comma'" - strippedStatement := parser.stripQuotesFromAlterStatement(alterStatement) + strippedStatement := parser.sanitizeQuotesFromAlterStatement(alterStatement) test.S(t).ExpectEquals(strippedStatement, "change column i int ''") } } From a11bec178518d16907e24a7baf83c881a78ef103 Mon Sep 17 00:00:00 2001 From: rj03hou Date: Thu, 1 Dec 2016 16:04:04 +0800 Subject: [PATCH 036/104] If the original table is MyISAM and the default engine is Innodb, and the gtid mode is on, there will be error when execute 'LOCK TABLES tbl WRITE, tbl_magic WRITE'. If make the magic cut-over table's engine same with the original table, there will not have this problem. --- go/logic/applier.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go/logic/applier.go b/go/logic/applier.go index 286728a..b9568b9 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -701,8 +701,9 @@ func (this *Applier) CreateAtomicCutOverSentryTable() error { query := fmt.Sprintf(`create /* gh-ost */ table %s.%s ( id int auto_increment primary key - ) comment='%s' + ) engine=%s comment='%s' `, + this.migrationContext.TableEngine, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(tableName), atomicCutOverMagicHint, From ffbd35e18083293fce248ddeb30261022c9c777e Mon Sep 17 00:00:00 2001 From: rj03hou Date: Fri, 2 Dec 2016 11:56:29 +0800 Subject: [PATCH 037/104] fix TableEngine correlates to the 3rd placeholder in the template string, not the 1st --- go/logic/applier.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/logic/applier.go b/go/logic/applier.go index 16d6e51..afff578 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -705,9 +705,9 @@ func (this *Applier) CreateAtomicCutOverSentryTable() error { id int auto_increment primary key ) engine=%s comment='%s' `, - this.migrationContext.TableEngine, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(tableName), + this.migrationContext.TableEngine, atomicCutOverMagicHint, ) log.Infof("Creating magic cut-over table %s.%s", From 8f02ab0fedb15d3aa57a699e08bd35f7d429c459 Mon Sep 17 00:00:00 2001 From: rj03hou Date: Mon, 5 Dec 2016 19:42:16 +0800 Subject: [PATCH 038/104] check the slave status when find recursive find the master, so support if the dba using reset slave instead of reset slave all. --- go/logic/inspect.go | 1 + go/mysql/utils.go | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 4519225..9b1ff5e 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -692,6 +692,7 @@ func (this *Inspector) readChangelogState() (map[string]string, error) { } func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { + log.Infof("Start find master, recursive to execute show slave status from the host to find the master.") visitedKeys := mysql.NewInstanceKeyMap() return mysql.GetMasterConnectionConfigSafe(this.connectionConfig, visitedKeys, this.migrationContext.AllowedMasterMaster) } diff --git a/go/mysql/utils.go b/go/mysql/utils.go index ae22e3a..de12e46 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -83,12 +83,30 @@ func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey return nil, err } err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { + // Using reset slave not reset slave all, the Master_Log_File is empty. + if rowMap.GetString("Master_Log_File") == "" { + return nil + } + + slaveIORunning := rowMap.GetString("Slave_IO_Running") + slaveSQLRunning := rowMap.GetString("Slave_SQL_Running") + + // If not using reset slave or reset slave all and slave is break, something is wrong, so we stop. + if slaveIORunning != "Yes" || slaveSQLRunning != "Yes" { + return fmt.Errorf("The slave %s is break: Slave_IO_Running:%s Slave_SQL_Running:%s please make sure the replication is correct.", + connectionConfig.Key.Hostname, + slaveIORunning, + slaveSQLRunning, + ) + } + masterKey = &InstanceKey{ Hostname: rowMap.GetString("Master_Host"), Port: rowMap.GetInt("Master_Port"), } return nil }) + return masterKey, err } From e157f692d58cabd1c186aa7397f617ebb565a076 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 5 Dec 2016 13:41:31 +0100 Subject: [PATCH 039/104] added 'trivial' test --- localtests/trivial/create.sql | 13 +++++++++++++ localtests/trivial/extra_args | 1 + 2 files changed, 14 insertions(+) create mode 100644 localtests/trivial/create.sql create mode 100644 localtests/trivial/extra_args diff --git a/localtests/trivial/create.sql b/localtests/trivial/create.sql new file mode 100644 index 0000000..9371238 --- /dev/null +++ b/localtests/trivial/create.sql @@ -0,0 +1,13 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + color varchar(32), + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; + +insert into gh_ost_test values (null, 11, 'red'); +insert into gh_ost_test values (null, 13, 'green'); +insert into gh_ost_test values (null, 17, 'blue'); diff --git a/localtests/trivial/extra_args b/localtests/trivial/extra_args new file mode 100644 index 0000000..8b6320a --- /dev/null +++ b/localtests/trivial/extra_args @@ -0,0 +1 @@ +--throttle-query='select false' \ From 35eeb560326e127ef1615cdf057f1bb145a226d8 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 5 Dec 2016 13:41:49 +0100 Subject: [PATCH 040/104] improved log/error messages --- go/logic/inspect.go | 2 +- go/mysql/utils.go | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 9b1ff5e..40921cc 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -692,7 +692,7 @@ func (this *Inspector) readChangelogState() (map[string]string, error) { } func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { - log.Infof("Start find master, recursive to execute show slave status from the host to find the master.") + log.Infof("Recursively searching for replication master") visitedKeys := mysql.NewInstanceKeyMap() return mysql.GetMasterConnectionConfigSafe(this.connectionConfig, visitedKeys, this.migrationContext.AllowedMasterMaster) } diff --git a/go/mysql/utils.go b/go/mysql/utils.go index de12e46..6b70c7c 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -83,7 +83,10 @@ func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey return nil, err } err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { - // Using reset slave not reset slave all, the Master_Log_File is empty. + // We wish to recognize the case where the topology's master actually has replication configuration. + // This can happen when a DBA issues a `RESET SLAVE` instead of `RESET SLAVE ALL`. + + // An empty log file indicates this is a master: if rowMap.GetString("Master_Log_File") == "" { return nil } @@ -91,10 +94,10 @@ func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey slaveIORunning := rowMap.GetString("Slave_IO_Running") slaveSQLRunning := rowMap.GetString("Slave_SQL_Running") - // If not using reset slave or reset slave all and slave is break, something is wrong, so we stop. + // if slaveIORunning != "Yes" || slaveSQLRunning != "Yes" { - return fmt.Errorf("The slave %s is break: Slave_IO_Running:%s Slave_SQL_Running:%s please make sure the replication is correct.", - connectionConfig.Key.Hostname, + return fmt.Errorf("Replication on %+v is broken: Slave_IO_Running: %s, Slave_SQL_Running: %s. Please make sure replication runs before using gh-ost.", + connectionConfig.Key, slaveIORunning, slaveSQLRunning, ) From c9edcd6f2020004810f2856f95f49d54002f584e Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 6 Dec 2016 13:11:56 +0100 Subject: [PATCH 041/104] updated version --- RELEASE_VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASE_VERSION b/RELEASE_VERSION index 8b54409..475bda9 100644 --- a/RELEASE_VERSION +++ b/RELEASE_VERSION @@ -1 +1 @@ -1.0.28 +1.0.30 From 1d84cb933c859370a182c7f58f42473ceedafb1b Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 11 Dec 2016 02:19:19 +0100 Subject: [PATCH 042/104] fix: bailing out on no PRIMARY/UNIQUE KEY --- go/logic/inspect.go | 2 +- localtests/no-unique-key/create.sql | 9 +++++++++ localtests/no-unique-key/expect_failure | 1 + localtests/no-unique-key/extra_args | 1 + 4 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 localtests/no-unique-key/create.sql create mode 100644 localtests/no-unique-key/expect_failure create mode 100644 localtests/no-unique-key/extra_args diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 40921cc..cd47c78 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -95,7 +95,7 @@ func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (colum func (this *Inspector) InspectOriginalTable() (err error) { this.migrationContext.OriginalTableColumns, this.migrationContext.OriginalTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.OriginalTableName) - if err == nil { + if err != nil { return err } return nil diff --git a/localtests/no-unique-key/create.sql b/localtests/no-unique-key/create.sql new file mode 100644 index 0000000..e854c3e --- /dev/null +++ b/localtests/no-unique-key/create.sql @@ -0,0 +1,9 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + i int not null, + ts timestamp default current_timestamp, + dt datetime, + key i_idx(i) +) auto_increment=1; + +drop event if exists gh_ost_test; diff --git a/localtests/no-unique-key/expect_failure b/localtests/no-unique-key/expect_failure new file mode 100644 index 0000000..021ae87 --- /dev/null +++ b/localtests/no-unique-key/expect_failure @@ -0,0 +1 @@ +No PRIMARY nor UNIQUE key found in table diff --git a/localtests/no-unique-key/extra_args b/localtests/no-unique-key/extra_args new file mode 100644 index 0000000..b13e72c --- /dev/null +++ b/localtests/no-unique-key/extra_args @@ -0,0 +1 @@ +--alter="add column v varchar(32)" From 3fd85ee8b12a32e26d1b63e2197c0264ea22b3cb Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 11 Dec 2016 02:22:56 +0100 Subject: [PATCH 043/104] test logging cleanup --- go/sql/parser_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/go/sql/parser_test.go b/go/sql/parser_test.go index 876429a..8039f5f 100644 --- a/go/sql/parser_test.go +++ b/go/sql/parser_test.go @@ -98,13 +98,11 @@ func TestTokenizeAlterStatement(t *testing.T) { { alterStatement := "add column t int, add column e enum('a','b','c')" tokens, _ := parser.tokenizeAlterStatement(alterStatement) - log.Errorf("%#v", tokens) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "add column e enum('a','b','c')"})) } { alterStatement := "add column t int(11), add column e enum('a','b','c')" tokens, _ := parser.tokenizeAlterStatement(alterStatement) - log.Errorf("%#v", tokens) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int(11)", "add column e enum('a','b','c')"})) } } From 7259dd6ac5519b45b866680bbae2bff0ccf9bd19 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 11 Dec 2016 02:55:21 +0100 Subject: [PATCH 044/104] supporting --ask-pass option --- go/cmd/gh-ost/main.go | 11 + .../golang.org/x/crypto/ssh/agent/client.go | 659 +++++++++++++ .../x/crypto/ssh/agent/client_test.go | 343 +++++++ .../x/crypto/ssh/agent/example_test.go | 40 + .../golang.org/x/crypto/ssh/agent/forward.go | 103 ++ .../golang.org/x/crypto/ssh/agent/keyring.go | 215 ++++ .../x/crypto/ssh/agent/keyring_test.go | 76 ++ .../golang.org/x/crypto/ssh/agent/server.go | 451 +++++++++ .../x/crypto/ssh/agent/server_test.go | 207 ++++ .../x/crypto/ssh/agent/testdata_test.go | 64 ++ .../golang.org/x/crypto/ssh/benchmark_test.go | 122 +++ vendor/golang.org/x/crypto/ssh/buffer.go | 98 ++ vendor/golang.org/x/crypto/ssh/buffer_test.go | 87 ++ vendor/golang.org/x/crypto/ssh/certs.go | 503 ++++++++++ vendor/golang.org/x/crypto/ssh/certs_test.go | 216 ++++ vendor/golang.org/x/crypto/ssh/channel.go | 633 ++++++++++++ vendor/golang.org/x/crypto/ssh/cipher.go | 579 +++++++++++ vendor/golang.org/x/crypto/ssh/cipher_test.go | 127 +++ vendor/golang.org/x/crypto/ssh/client.go | 213 ++++ vendor/golang.org/x/crypto/ssh/client_auth.go | 473 +++++++++ .../x/crypto/ssh/client_auth_test.go | 472 +++++++++ vendor/golang.org/x/crypto/ssh/client_test.go | 39 + vendor/golang.org/x/crypto/ssh/common.go | 356 +++++++ vendor/golang.org/x/crypto/ssh/connection.go | 143 +++ vendor/golang.org/x/crypto/ssh/doc.go | 18 + .../golang.org/x/crypto/ssh/example_test.go | 262 +++++ vendor/golang.org/x/crypto/ssh/handshake.go | 460 +++++++++ .../golang.org/x/crypto/ssh/handshake_test.go | 486 +++++++++ vendor/golang.org/x/crypto/ssh/kex.go | 540 ++++++++++ vendor/golang.org/x/crypto/ssh/kex_test.go | 50 + vendor/golang.org/x/crypto/ssh/keys.go | 905 +++++++++++++++++ vendor/golang.org/x/crypto/ssh/keys_test.go | 474 +++++++++ vendor/golang.org/x/crypto/ssh/mac.go | 57 ++ .../golang.org/x/crypto/ssh/mempipe_test.go | 110 +++ vendor/golang.org/x/crypto/ssh/messages.go | 758 ++++++++++++++ .../golang.org/x/crypto/ssh/messages_test.go | 288 ++++++ vendor/golang.org/x/crypto/ssh/mux.go | 330 +++++++ vendor/golang.org/x/crypto/ssh/mux_test.go | 502 ++++++++++ vendor/golang.org/x/crypto/ssh/server.go | 488 +++++++++ vendor/golang.org/x/crypto/ssh/session.go | 627 ++++++++++++ .../golang.org/x/crypto/ssh/session_test.go | 770 +++++++++++++++ vendor/golang.org/x/crypto/ssh/tcpip.go | 407 ++++++++ vendor/golang.org/x/crypto/ssh/tcpip_test.go | 20 + .../x/crypto/ssh/terminal/terminal.go | 924 ++++++++++++++++++ .../x/crypto/ssh/terminal/terminal_test.go | 306 ++++++ .../golang.org/x/crypto/ssh/terminal/util.go | 133 +++ .../x/crypto/ssh/terminal/util_bsd.go | 12 + .../x/crypto/ssh/terminal/util_linux.go | 11 + .../x/crypto/ssh/terminal/util_plan9.go | 58 ++ .../x/crypto/ssh/terminal/util_solaris.go | 73 ++ .../x/crypto/ssh/terminal/util_windows.go | 174 ++++ .../x/crypto/ssh/test/agent_unix_test.go | 59 ++ .../golang.org/x/crypto/ssh/test/cert_test.go | 47 + vendor/golang.org/x/crypto/ssh/test/doc.go | 7 + .../x/crypto/ssh/test/forward_unix_test.go | 160 +++ .../x/crypto/ssh/test/session_test.go | 365 +++++++ .../x/crypto/ssh/test/tcpip_test.go | 46 + .../x/crypto/ssh/test/test_unix_test.go | 268 +++++ .../x/crypto/ssh/test/testdata_test.go | 64 ++ .../golang.org/x/crypto/ssh/testdata/doc.go | 8 + .../golang.org/x/crypto/ssh/testdata/keys.go | 120 +++ .../golang.org/x/crypto/ssh/testdata_test.go | 63 ++ vendor/golang.org/x/crypto/ssh/transport.go | 333 +++++++ .../golang.org/x/crypto/ssh/transport_test.go | 109 +++ 64 files changed, 17122 insertions(+) create mode 100644 vendor/golang.org/x/crypto/ssh/agent/client.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/client_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/example_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/forward.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/keyring.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/keyring_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/server.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/server_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/agent/testdata_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/benchmark_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/buffer.go create mode 100644 vendor/golang.org/x/crypto/ssh/buffer_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/certs.go create mode 100644 vendor/golang.org/x/crypto/ssh/certs_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/channel.go create mode 100644 vendor/golang.org/x/crypto/ssh/cipher.go create mode 100644 vendor/golang.org/x/crypto/ssh/cipher_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/client.go create mode 100644 vendor/golang.org/x/crypto/ssh/client_auth.go create mode 100644 vendor/golang.org/x/crypto/ssh/client_auth_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/client_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/common.go create mode 100644 vendor/golang.org/x/crypto/ssh/connection.go create mode 100644 vendor/golang.org/x/crypto/ssh/doc.go create mode 100644 vendor/golang.org/x/crypto/ssh/example_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/handshake.go create mode 100644 vendor/golang.org/x/crypto/ssh/handshake_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/kex.go create mode 100644 vendor/golang.org/x/crypto/ssh/kex_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/keys.go create mode 100644 vendor/golang.org/x/crypto/ssh/keys_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/mac.go create mode 100644 vendor/golang.org/x/crypto/ssh/mempipe_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/messages.go create mode 100644 vendor/golang.org/x/crypto/ssh/messages_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/mux.go create mode 100644 vendor/golang.org/x/crypto/ssh/mux_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/server.go create mode 100644 vendor/golang.org/x/crypto/ssh/session.go create mode 100644 vendor/golang.org/x/crypto/ssh/session_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/tcpip.go create mode 100644 vendor/golang.org/x/crypto/ssh/tcpip_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/terminal.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/terminal_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/util.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/util_bsd.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/util_linux.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/util_plan9.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/util_solaris.go create mode 100644 vendor/golang.org/x/crypto/ssh/terminal/util_windows.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/agent_unix_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/cert_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/doc.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/forward_unix_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/session_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/tcpip_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/test_unix_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/testdata_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/testdata/doc.go create mode 100644 vendor/golang.org/x/crypto/ssh/testdata/keys.go create mode 100644 vendor/golang.org/x/crypto/ssh/testdata_test.go create mode 100644 vendor/golang.org/x/crypto/ssh/transport.go create mode 100644 vendor/golang.org/x/crypto/ssh/transport_test.go diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 5787b7a..63bb230 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -15,6 +15,8 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/logic" "github.com/outbrain/golib/log" + + "golang.org/x/crypto/ssh/terminal" ) var AppVersion string @@ -49,6 +51,7 @@ func main() { flag.StringVar(&migrationContext.CliUser, "user", "", "MySQL user") flag.StringVar(&migrationContext.CliPassword, "password", "", "MySQL password") flag.StringVar(&migrationContext.ConfigFile, "conf", "", "Config file") + askPass := flag.Bool("ask-pass", false, "prompt for MySQL password") flag.StringVar(&migrationContext.DatabaseName, "database", "", "database name (mandatory)") flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)") @@ -194,6 +197,14 @@ func main() { if migrationContext.ServeSocketFile == "" { migrationContext.ServeSocketFile = fmt.Sprintf("/tmp/gh-ost.%s.%s.sock", migrationContext.DatabaseName, migrationContext.OriginalTableName) } + if *askPass { + fmt.Println("Password:") + bytePassword, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + log.Fatale(err) + } + migrationContext.CliPassword = string(bytePassword) + } migrationContext.SetHeartbeatIntervalMilliseconds(*heartbeatIntervalMillis) migrationContext.SetNiceRatio(*niceRatio) migrationContext.SetChunkSize(*chunkSize) diff --git a/vendor/golang.org/x/crypto/ssh/agent/client.go b/vendor/golang.org/x/crypto/ssh/agent/client.go new file mode 100644 index 0000000..ecfd7c5 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/client.go @@ -0,0 +1,659 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package agent implements the ssh-agent protocol, and provides both +// a client and a server. The client can talk to a standard ssh-agent +// that uses UNIX sockets, and one could implement an alternative +// ssh-agent process using the sample server. +// +// References: +// [PROTOCOL.agent]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent?rev=HEAD +package agent // import "golang.org/x/crypto/ssh/agent" + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "sync" + + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" +) + +// Agent represents the capabilities of an ssh-agent. +type Agent interface { + // List returns the identities known to the agent. + List() ([]*Key, error) + + // Sign has the agent sign the data using a protocol 2 key as defined + // in [PROTOCOL.agent] section 2.6.2. + Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) + + // Add adds a private key to the agent. + Add(key AddedKey) error + + // Remove removes all identities with the given public key. + Remove(key ssh.PublicKey) error + + // RemoveAll removes all identities. + RemoveAll() error + + // Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. + Lock(passphrase []byte) error + + // Unlock undoes the effect of Lock + Unlock(passphrase []byte) error + + // Signers returns signers for all the known keys. + Signers() ([]ssh.Signer, error) +} + +// AddedKey describes an SSH key to be added to an Agent. +type AddedKey struct { + // PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or + // *ecdsa.PrivateKey, which will be inserted into the agent. + PrivateKey interface{} + // Certificate, if not nil, is communicated to the agent and will be + // stored with the key. + Certificate *ssh.Certificate + // Comment is an optional, free-form string. + Comment string + // LifetimeSecs, if not zero, is the number of seconds that the + // agent will store the key for. + LifetimeSecs uint32 + // ConfirmBeforeUse, if true, requests that the agent confirm with the + // user before each use of this key. + ConfirmBeforeUse bool +} + +// See [PROTOCOL.agent], section 3. +const ( + agentRequestV1Identities = 1 + agentRemoveAllV1Identities = 9 + + // 3.2 Requests from client to agent for protocol 2 key operations + agentAddIdentity = 17 + agentRemoveIdentity = 18 + agentRemoveAllIdentities = 19 + agentAddIdConstrained = 25 + + // 3.3 Key-type independent requests from client to agent + agentAddSmartcardKey = 20 + agentRemoveSmartcardKey = 21 + agentLock = 22 + agentUnlock = 23 + agentAddSmartcardKeyConstrained = 26 + + // 3.7 Key constraint identifiers + agentConstrainLifetime = 1 + agentConstrainConfirm = 2 +) + +// maxAgentResponseBytes is the maximum agent reply size that is accepted. This +// is a sanity check, not a limit in the spec. +const maxAgentResponseBytes = 16 << 20 + +// Agent messages: +// These structures mirror the wire format of the corresponding ssh agent +// messages found in [PROTOCOL.agent]. + +// 3.4 Generic replies from agent to client +const agentFailure = 5 + +type failureAgentMsg struct{} + +const agentSuccess = 6 + +type successAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentRequestIdentities = 11 + +type requestIdentitiesAgentMsg struct{} + +// See [PROTOCOL.agent], section 2.5.2. +const agentIdentitiesAnswer = 12 + +type identitiesAnswerAgentMsg struct { + NumKeys uint32 `sshtype:"12"` + Keys []byte `ssh:"rest"` +} + +// See [PROTOCOL.agent], section 2.6.2. +const agentSignRequest = 13 + +type signRequestAgentMsg struct { + KeyBlob []byte `sshtype:"13"` + Data []byte + Flags uint32 +} + +// See [PROTOCOL.agent], section 2.6.2. + +// 3.6 Replies from agent to client for protocol 2 key operations +const agentSignResponse = 14 + +type signResponseAgentMsg struct { + SigBlob []byte `sshtype:"14"` +} + +type publicKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +// Key represents a protocol 2 public key as defined in +// [PROTOCOL.agent], section 2.5.2. +type Key struct { + Format string + Blob []byte + Comment string +} + +func clientErr(err error) error { + return fmt.Errorf("agent: client error: %v", err) +} + +// String returns the storage form of an agent key with the format, base64 +// encoded serialized key, and the comment if it is not empty. +func (k *Key) String() string { + s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) + + if k.Comment != "" { + s += " " + k.Comment + } + + return s +} + +// Type returns the public key type. +func (k *Key) Type() string { + return k.Format +} + +// Marshal returns key blob to satisfy the ssh.PublicKey interface. +func (k *Key) Marshal() []byte { + return k.Blob +} + +// Verify satisfies the ssh.PublicKey interface. +func (k *Key) Verify(data []byte, sig *ssh.Signature) error { + pubKey, err := ssh.ParsePublicKey(k.Blob) + if err != nil { + return fmt.Errorf("agent: bad public key: %v", err) + } + return pubKey.Verify(data, sig) +} + +type wireKey struct { + Format string + Rest []byte `ssh:"rest"` +} + +func parseKey(in []byte) (out *Key, rest []byte, err error) { + var record struct { + Blob []byte + Comment string + Rest []byte `ssh:"rest"` + } + + if err := ssh.Unmarshal(in, &record); err != nil { + return nil, nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(record.Blob, &wk); err != nil { + return nil, nil, err + } + + return &Key{ + Format: wk.Format, + Blob: record.Blob, + Comment: record.Comment, + }, record.Rest, nil +} + +// client is a client for an ssh-agent process. +type client struct { + // conn is typically a *net.UnixConn + conn io.ReadWriter + // mu is used to prevent concurrent access to the agent + mu sync.Mutex +} + +// NewClient returns an Agent that talks to an ssh-agent process over +// the given connection. +func NewClient(rw io.ReadWriter) Agent { + return &client{conn: rw} +} + +// call sends an RPC to the agent. On success, the reply is +// unmarshaled into reply and replyType is set to the first byte of +// the reply, which contains the type of the message. +func (c *client) call(req []byte) (reply interface{}, err error) { + c.mu.Lock() + defer c.mu.Unlock() + + msg := make([]byte, 4+len(req)) + binary.BigEndian.PutUint32(msg, uint32(len(req))) + copy(msg[4:], req) + if _, err = c.conn.Write(msg); err != nil { + return nil, clientErr(err) + } + + var respSizeBuf [4]byte + if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { + return nil, clientErr(err) + } + respSize := binary.BigEndian.Uint32(respSizeBuf[:]) + if respSize > maxAgentResponseBytes { + return nil, clientErr(err) + } + + buf := make([]byte, respSize) + if _, err = io.ReadFull(c.conn, buf); err != nil { + return nil, clientErr(err) + } + reply, err = unmarshal(buf) + if err != nil { + return nil, clientErr(err) + } + return reply, err +} + +func (c *client) simpleCall(req []byte) error { + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +func (c *client) RemoveAll() error { + return c.simpleCall([]byte{agentRemoveAllIdentities}) +} + +func (c *client) Remove(key ssh.PublicKey) error { + req := ssh.Marshal(&agentRemoveIdentityMsg{ + KeyBlob: key.Marshal(), + }) + return c.simpleCall(req) +} + +func (c *client) Lock(passphrase []byte) error { + req := ssh.Marshal(&agentLockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +func (c *client) Unlock(passphrase []byte) error { + req := ssh.Marshal(&agentUnlockMsg{ + Passphrase: passphrase, + }) + return c.simpleCall(req) +} + +// List returns the identities known to the agent. +func (c *client) List() ([]*Key, error) { + // see [PROTOCOL.agent] section 2.5.2. + req := []byte{agentRequestIdentities} + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *identitiesAnswerAgentMsg: + if msg.NumKeys > maxAgentResponseBytes/8 { + return nil, errors.New("agent: too many keys in agent reply") + } + keys := make([]*Key, msg.NumKeys) + data := msg.Keys + for i := uint32(0); i < msg.NumKeys; i++ { + var key *Key + var err error + if key, data, err = parseKey(data); err != nil { + return nil, err + } + keys[i] = key + } + return keys, nil + case *failureAgentMsg: + return nil, errors.New("agent: failed to list keys") + } + panic("unreachable") +} + +// Sign has the agent sign the data using a protocol 2 key as defined +// in [PROTOCOL.agent] section 2.6.2. +func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + req := ssh.Marshal(signRequestAgentMsg{ + KeyBlob: key.Marshal(), + Data: data, + }) + + msg, err := c.call(req) + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *signResponseAgentMsg: + var sig ssh.Signature + if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil { + return nil, err + } + + return &sig, nil + case *failureAgentMsg: + return nil, errors.New("agent: failed to sign challenge") + } + panic("unreachable") +} + +// unmarshal parses an agent message in packet, returning the parsed +// form and the message type of packet. +func unmarshal(packet []byte) (interface{}, error) { + if len(packet) < 1 { + return nil, errors.New("agent: empty packet") + } + var msg interface{} + switch packet[0] { + case agentFailure: + return new(failureAgentMsg), nil + case agentSuccess: + return new(successAgentMsg), nil + case agentIdentitiesAnswer: + msg = new(identitiesAnswerAgentMsg) + case agentSignResponse: + msg = new(signResponseAgentMsg) + case agentV1IdentitiesAnswer: + msg = new(agentV1IdentityMsg) + default: + return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) + } + if err := ssh.Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +type rsaKeyMsg struct { + Type string `sshtype:"17|25"` + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type dsaKeyMsg struct { + Type string `sshtype:"17|25"` + P *big.Int + Q *big.Int + G *big.Int + Y *big.Int + X *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ecdsaKeyMsg struct { + Type string `sshtype:"17|25"` + Curve string + KeyBytes []byte + D *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ed25519KeyMsg struct { + Type string `sshtype:"17|25"` + Pub []byte + Priv []byte + Comments string + Constraints []byte `ssh:"rest"` +} + +// Insert adds a private key to the agent. +func (c *client) insertKey(s interface{}, comment string, constraints []byte) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = ssh.Marshal(rsaKeyMsg{ + Type: ssh.KeyAlgoRSA, + N: k.N, + E: big.NewInt(int64(k.E)), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + Constraints: constraints, + }) + case *dsa.PrivateKey: + req = ssh.Marshal(dsaKeyMsg{ + Type: ssh.KeyAlgoDSA, + P: k.P, + Q: k.Q, + G: k.G, + Y: k.Y, + X: k.X, + Comments: comment, + Constraints: constraints, + }) + case *ecdsa.PrivateKey: + nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) + req = ssh.Marshal(ecdsaKeyMsg{ + Type: "ecdsa-sha2-" + nistID, + Curve: nistID, + KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), + D: k.D, + Comments: comment, + Constraints: constraints, + }) + case *ed25519.PrivateKey: + req = ssh.Marshal(ed25519KeyMsg{ + Type: ssh.KeyAlgoED25519, + Pub: []byte(*k)[32:], + Priv: []byte(*k), + Comments: comment, + Constraints: constraints, + }) + default: + return fmt.Errorf("agent: unsupported key type %T", s) + } + + // if constraints are present then the message type needs to be changed. + if len(constraints) != 0 { + req[0] = agentAddIdConstrained + } + + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +type rsaCertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + D *big.Int + Iqmp *big.Int // IQMP = Inverse Q Mod P + P *big.Int + Q *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type dsaCertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + X *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ecdsaCertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + D *big.Int + Comments string + Constraints []byte `ssh:"rest"` +} + +type ed25519CertMsg struct { + Type string `sshtype:"17|25"` + CertBytes []byte + Pub []byte + Priv []byte + Comments string + Constraints []byte `ssh:"rest"` +} + +// Add adds a private key to the agent. If a certificate is given, +// that certificate is added instead as public key. +func (c *client) Add(key AddedKey) error { + var constraints []byte + + if secs := key.LifetimeSecs; secs != 0 { + constraints = append(constraints, agentConstrainLifetime) + + var secsBytes [4]byte + binary.BigEndian.PutUint32(secsBytes[:], secs) + constraints = append(constraints, secsBytes[:]...) + } + + if key.ConfirmBeforeUse { + constraints = append(constraints, agentConstrainConfirm) + } + + if cert := key.Certificate; cert == nil { + return c.insertKey(key.PrivateKey, key.Comment, constraints) + } else { + return c.insertCert(key.PrivateKey, cert, key.Comment, constraints) + } +} + +func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { + var req []byte + switch k := s.(type) { + case *rsa.PrivateKey: + if len(k.Primes) != 2 { + return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) + } + k.Precompute() + req = ssh.Marshal(rsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comments: comment, + Constraints: constraints, + }) + case *dsa.PrivateKey: + req = ssh.Marshal(dsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + X: k.X, + Comments: comment, + Constraints: constraints, + }) + case *ecdsa.PrivateKey: + req = ssh.Marshal(ecdsaCertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + D: k.D, + Comments: comment, + Constraints: constraints, + }) + case *ed25519.PrivateKey: + req = ssh.Marshal(ed25519CertMsg{ + Type: cert.Type(), + CertBytes: cert.Marshal(), + Pub: []byte(*k)[32:], + Priv: []byte(*k), + Comments: comment, + Constraints: constraints, + }) + default: + return fmt.Errorf("agent: unsupported key type %T", s) + } + + // if constraints are present then the message type needs to be changed. + if len(constraints) != 0 { + req[0] = agentAddIdConstrained + } + + signer, err := ssh.NewSignerFromKey(s) + if err != nil { + return err + } + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return errors.New("agent: signer and cert have different public key") + } + + resp, err := c.call(req) + if err != nil { + return err + } + if _, ok := resp.(*successAgentMsg); ok { + return nil + } + return errors.New("agent: failure") +} + +// Signers provides a callback for client authentication. +func (c *client) Signers() ([]ssh.Signer, error) { + keys, err := c.List() + if err != nil { + return nil, err + } + + var result []ssh.Signer + for _, k := range keys { + result = append(result, &agentKeyringSigner{c, k}) + } + return result, nil +} + +type agentKeyringSigner struct { + agent *client + pub ssh.PublicKey +} + +func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { + return s.pub +} + +func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + // The agent has its own entropy source, so the rand argument is ignored. + return s.agent.Sign(s.pub, data) +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/client_test.go b/vendor/golang.org/x/crypto/ssh/agent/client_test.go new file mode 100644 index 0000000..e33d471 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/client_test.go @@ -0,0 +1,343 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "errors" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +// startAgent executes ssh-agent, and returns a Agent interface to it. +func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { + if testing.Short() { + // ssh-agent is not always available, and the key + // types supported vary by platform. + t.Skip("skipping test due to -short") + } + + bin, err := exec.LookPath("ssh-agent") + if err != nil { + t.Skip("could not find ssh-agent") + } + + cmd := exec.Command(bin, "-s") + out, err := cmd.Output() + if err != nil { + t.Fatalf("cmd.Output: %v", err) + } + + /* Output looks like: + + SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; + SSH_AGENT_PID=15542; export SSH_AGENT_PID; + echo Agent pid 15542; + */ + fields := bytes.Split(out, []byte(";")) + line := bytes.SplitN(fields[0], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AUTH_SOCK" { + t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) + } + socket = string(line[1]) + + line = bytes.SplitN(fields[2], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AGENT_PID" { + t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) + } + pidStr := line[1] + pid, err := strconv.Atoi(string(pidStr)) + if err != nil { + t.Fatalf("Atoi(%q): %v", pidStr, err) + } + + conn, err := net.Dial("unix", string(socket)) + if err != nil { + t.Fatalf("net.Dial: %v", err) + } + + ac := NewClient(conn) + return ac, socket, func() { + proc, _ := os.FindProcess(pid) + if proc != nil { + proc.Kill() + } + conn.Close() + os.RemoveAll(filepath.Dir(socket)) + } +} + +func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { + agent, _, cleanup := startAgent(t) + defer cleanup() + + testAgentInterface(t, agent, key, cert, lifetimeSecs) +} + +func testKeyring(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { + a := NewKeyring() + testAgentInterface(t, a, key, cert, lifetimeSecs) +} + +func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + t.Fatalf("NewSignerFromKey(%T): %v", key, err) + } + // The agent should start up empty. + if keys, err := agent.List(); err != nil { + t.Fatalf("RequestIdentities: %v", err) + } else if len(keys) > 0 { + t.Fatalf("got %d keys, want 0: %v", len(keys), keys) + } + + // Attempt to insert the key, with certificate if specified. + var pubKey ssh.PublicKey + if cert != nil { + err = agent.Add(AddedKey{ + PrivateKey: key, + Certificate: cert, + Comment: "comment", + LifetimeSecs: lifetimeSecs, + }) + pubKey = cert + } else { + err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs}) + pubKey = signer.PublicKey() + } + if err != nil { + t.Fatalf("insert(%T): %v", key, err) + } + + // Did the key get inserted successfully? + if keys, err := agent.List(); err != nil { + t.Fatalf("List: %v", err) + } else if len(keys) != 1 { + t.Fatalf("got %v, want 1 key", keys) + } else if keys[0].Comment != "comment" { + t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") + } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { + t.Fatalf("key mismatch") + } + + // Can the agent make a valid signature? + data := []byte("hello") + sig, err := agent.Sign(pubKey, data) + if err != nil { + t.Fatalf("Sign(%s): %v", pubKey.Type(), err) + } + + if err := pubKey.Verify(data, sig); err != nil { + t.Fatalf("Verify(%s): %v", pubKey.Type(), err) + } + + // If the key has a lifetime, is it removed when it should be? + if lifetimeSecs > 0 { + time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond) + keys, err := agent.List() + if err != nil { + t.Fatalf("List: %v", err) + } + if len(keys) > 0 { + t.Fatalf("key not expired") + } + } + +} + +func TestAgent(t *testing.T) { + for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} { + testAgent(t, testPrivateKeys[keyType], nil, 0) + testKeyring(t, testPrivateKeys[keyType], nil, 1) + } +} + +func TestCert(t *testing.T) { + cert := &ssh.Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: ssh.CertTimeInfinity, + CertType: ssh.UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + testAgent(t, testPrivateKeys["rsa"], cert, 0) + testKeyring(t, testPrivateKeys["rsa"], cert, 1) +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func TestAuth(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + agent, _, cleanup := startAgent(t) + defer cleanup() + + if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { + t.Errorf("Add: %v", err) + } + + serverConf := ssh.ServerConfig{} + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, errors.New("pubkey rejected") + } + + go func() { + conn, _, _, err := ssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + conn.Close() + }() + + conf := ssh.ClientConfig{} + conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) + conn, _, _, err := ssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + conn.Close() +} + +func TestLockClient(t *testing.T) { + agent, _, cleanup := startAgent(t) + defer cleanup() + testLockAgent(agent, t) +} + +func testLockAgent(agent Agent, t *testing.T) { + if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil { + t.Errorf("Add: %v", err) + } + if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil { + t.Errorf("Add: %v", err) + } + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 2 { + t.Errorf("Want 2 keys, got %v", keys) + } + + passphrase := []byte("secret") + if err := agent.Lock(passphrase); err != nil { + t.Errorf("Lock: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 0 { + t.Errorf("Want 0 keys, got %v", keys) + } + + signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) + if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { + t.Fatalf("Sign did not fail") + } + + if err := agent.Remove(signer.PublicKey()); err == nil { + t.Fatalf("Remove did not fail") + } + + if err := agent.RemoveAll(); err == nil { + t.Fatalf("RemoveAll did not fail") + } + + if err := agent.Unlock(nil); err == nil { + t.Errorf("Unlock with wrong passphrase succeeded") + } + if err := agent.Unlock(passphrase); err != nil { + t.Errorf("Unlock: %v", err) + } + + if err := agent.Remove(signer.PublicKey()); err != nil { + t.Fatalf("Remove: %v", err) + } + + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 1 { + t.Errorf("Want 1 keys, got %v", keys) + } +} + +func TestAgentLifetime(t *testing.T) { + agent, _, cleanup := startAgent(t) + defer cleanup() + + for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { + // Add private keys to the agent. + err := agent.Add(AddedKey{ + PrivateKey: testPrivateKeys[keyType], + Comment: "comment", + LifetimeSecs: 1, + }) + if err != nil { + t.Fatalf("add: %v", err) + } + // Add certs to the agent. + cert := &ssh.Certificate{ + Key: testPublicKeys[keyType], + ValidBefore: ssh.CertTimeInfinity, + CertType: ssh.UserCert, + } + cert.SignCert(rand.Reader, testSigners[keyType]) + err = agent.Add(AddedKey{ + PrivateKey: testPrivateKeys[keyType], + Certificate: cert, + Comment: "comment", + LifetimeSecs: 1, + }) + if err != nil { + t.Fatalf("add: %v", err) + } + } + time.Sleep(1100 * time.Millisecond) + if keys, err := agent.List(); err != nil { + t.Errorf("List: %v", err) + } else if len(keys) != 0 { + t.Errorf("Want 0 keys, got %v", len(keys)) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/example_test.go b/vendor/golang.org/x/crypto/ssh/agent/example_test.go new file mode 100644 index 0000000..c1130f7 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/example_test.go @@ -0,0 +1,40 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent_test + +import ( + "log" + "os" + "net" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func ExampleClientAgent() { + // ssh-agent has a UNIX socket under $SSH_AUTH_SOCK + socket := os.Getenv("SSH_AUTH_SOCK") + conn, err := net.Dial("unix", socket) + if err != nil { + log.Fatalf("net.Dial: %v", err) + } + agentClient := agent.NewClient(conn) + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + // Use a callback rather than PublicKeys + // so we only consult the agent once the remote server + // wants it. + ssh.PublicKeysCallback(agentClient.Signers), + }, + } + + sshc, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatalf("Dial: %v", err) + } + // .. use sshc + sshc.Close() +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/forward.go b/vendor/golang.org/x/crypto/ssh/agent/forward.go new file mode 100644 index 0000000..fd24ba9 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/forward.go @@ -0,0 +1,103 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "errors" + "io" + "net" + "sync" + + "golang.org/x/crypto/ssh" +) + +// RequestAgentForwarding sets up agent forwarding for the session. +// ForwardToAgent or ForwardToRemote should be called to route +// the authentication requests. +func RequestAgentForwarding(session *ssh.Session) error { + ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) + if err != nil { + return err + } + if !ok { + return errors.New("forwarding request denied") + } + return nil +} + +// ForwardToAgent routes authentication requests to the given keyring. +func ForwardToAgent(client *ssh.Client, keyring Agent) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go ssh.DiscardRequests(reqs) + go func() { + ServeAgent(keyring, channel) + channel.Close() + }() + } + }() + return nil +} + +const channelType = "auth-agent@openssh.com" + +// ForwardToRemote routes authentication requests to the ssh-agent +// process serving on the given unix socket. +func ForwardToRemote(client *ssh.Client, addr string) error { + channels := client.HandleChannelOpen(channelType) + if channels == nil { + return errors.New("agent: already have handler for " + channelType) + } + conn, err := net.Dial("unix", addr) + if err != nil { + return err + } + conn.Close() + + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + if err != nil { + continue + } + go ssh.DiscardRequests(reqs) + go forwardUnixSocket(channel, addr) + } + }() + return nil +} + +func forwardUnixSocket(channel ssh.Channel, addr string) { + conn, err := net.Dial("unix", addr) + if err != nil { + return + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(conn, channel) + conn.(*net.UnixConn).CloseWrite() + wg.Done() + }() + go func() { + io.Copy(channel, conn) + channel.CloseWrite() + wg.Done() + }() + + wg.Wait() + conn.Close() + channel.Close() +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/keyring.go b/vendor/golang.org/x/crypto/ssh/agent/keyring.go new file mode 100644 index 0000000..a6ba06a --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/keyring.go @@ -0,0 +1,215 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +type privKey struct { + signer ssh.Signer + comment string + expire *time.Time +} + +type keyring struct { + mu sync.Mutex + keys []privKey + + locked bool + passphrase []byte +} + +var errLocked = errors.New("agent: locked") + +// NewKeyring returns an Agent that holds keys in memory. It is safe +// for concurrent use by multiple goroutines. +func NewKeyring() Agent { + return &keyring{} +} + +// RemoveAll removes all identities. +func (r *keyring) RemoveAll() error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.keys = nil + return nil +} + +// removeLocked does the actual key removal. The caller must already be holding the +// keyring mutex. +func (r *keyring) removeLocked(want []byte) error { + found := false + for i := 0; i < len(r.keys); { + if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { + found = true + r.keys[i] = r.keys[len(r.keys)-1] + r.keys = r.keys[:len(r.keys)-1] + continue + } else { + i++ + } + } + + if !found { + return errors.New("agent: key not found") + } + return nil +} + +// Remove removes all identities with the given public key. +func (r *keyring) Remove(key ssh.PublicKey) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + return r.removeLocked(key.Marshal()) +} + +// Lock locks the agent. Sign and Remove will fail, and List will return an empty list. +func (r *keyring) Lock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + + r.locked = true + r.passphrase = passphrase + return nil +} + +// Unlock undoes the effect of Lock +func (r *keyring) Unlock(passphrase []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + if !r.locked { + return errors.New("agent: not locked") + } + if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { + return fmt.Errorf("agent: incorrect passphrase") + } + + r.locked = false + r.passphrase = nil + return nil +} + +// expireKeysLocked removes expired keys from the keyring. If a key was added +// with a lifetimesecs contraint and seconds >= lifetimesecs seconds have +// ellapsed, it is removed. The caller *must* be holding the keyring mutex. +func (r *keyring) expireKeysLocked() { + for _, k := range r.keys { + if k.expire != nil && time.Now().After(*k.expire) { + r.removeLocked(k.signer.PublicKey().Marshal()) + } + } +} + +// List returns the identities known to the agent. +func (r *keyring) List() ([]*Key, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + // section 2.7: locked agents return empty. + return nil, nil + } + + r.expireKeysLocked() + var ids []*Key + for _, k := range r.keys { + pub := k.signer.PublicKey() + ids = append(ids, &Key{ + Format: pub.Type(), + Blob: pub.Marshal(), + Comment: k.comment}) + } + return ids, nil +} + +// Insert adds a private key to the keyring. If a certificate +// is given, that certificate is added as public key. Note that +// any constraints given are ignored. +func (r *keyring) Add(key AddedKey) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return errLocked + } + signer, err := ssh.NewSignerFromKey(key.PrivateKey) + + if err != nil { + return err + } + + if cert := key.Certificate; cert != nil { + signer, err = ssh.NewCertSigner(cert, signer) + if err != nil { + return err + } + } + + p := privKey{ + signer: signer, + comment: key.Comment, + } + + if key.LifetimeSecs > 0 { + t := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second) + p.expire = &t + } + + r.keys = append(r.keys, p) + + return nil +} + +// Sign returns a signature for the data. +func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + r.expireKeysLocked() + wanted := key.Marshal() + for _, k := range r.keys { + if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { + return k.signer.Sign(rand.Reader, data) + } + } + return nil, errors.New("not found") +} + +// Signers returns signers for all the known keys. +func (r *keyring) Signers() ([]ssh.Signer, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.locked { + return nil, errLocked + } + + r.expireKeysLocked() + s := make([]ssh.Signer, 0, len(r.keys)) + for _, k := range r.keys { + s = append(s, k.signer) + } + return s, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/keyring_test.go b/vendor/golang.org/x/crypto/ssh/agent/keyring_test.go new file mode 100644 index 0000000..e5d50e7 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/keyring_test.go @@ -0,0 +1,76 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import "testing" + +func addTestKey(t *testing.T, a Agent, keyName string) { + err := a.Add(AddedKey{ + PrivateKey: testPrivateKeys[keyName], + Comment: keyName, + }) + if err != nil { + t.Fatalf("failed to add key %q: %v", keyName, err) + } +} + +func removeTestKey(t *testing.T, a Agent, keyName string) { + err := a.Remove(testPublicKeys[keyName]) + if err != nil { + t.Fatalf("failed to remove key %q: %v", keyName, err) + } +} + +func validateListedKeys(t *testing.T, a Agent, expectedKeys []string) { + listedKeys, err := a.List() + if err != nil { + t.Fatalf("failed to list keys: %v", err) + return + } + actualKeys := make(map[string]bool) + for _, key := range listedKeys { + actualKeys[key.Comment] = true + } + + matchedKeys := make(map[string]bool) + for _, expectedKey := range expectedKeys { + if !actualKeys[expectedKey] { + t.Fatalf("expected key %q, but was not found", expectedKey) + } else { + matchedKeys[expectedKey] = true + } + } + + for actualKey := range actualKeys { + if !matchedKeys[actualKey] { + t.Fatalf("key %q was found, but was not expected", actualKey) + } + } +} + +func TestKeyringAddingAndRemoving(t *testing.T) { + keyNames := []string{"dsa", "ecdsa", "rsa", "user"} + + // add all test private keys + k := NewKeyring() + for _, keyName := range keyNames { + addTestKey(t, k, keyName) + } + validateListedKeys(t, k, keyNames) + + // remove a key in the middle + keyToRemove := keyNames[1] + keyNames = append(keyNames[:1], keyNames[2:]...) + + removeTestKey(t, k, keyToRemove) + validateListedKeys(t, k, keyNames) + + // remove all keys + err := k.RemoveAll() + if err != nil { + t.Fatalf("failed to remove all keys: %v", err) + } + validateListedKeys(t, k, []string{}) +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/server.go b/vendor/golang.org/x/crypto/ssh/agent/server.go new file mode 100644 index 0000000..68a333f --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/server.go @@ -0,0 +1,451 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "math/big" + + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" +) + +// Server wraps an Agent and uses it to implement the agent side of +// the SSH-agent, wire protocol. +type server struct { + agent Agent +} + +func (s *server) processRequestBytes(reqData []byte) []byte { + rep, err := s.processRequest(reqData) + if err != nil { + if err != errLocked { + // TODO(hanwen): provide better logging interface? + log.Printf("agent %d: %v", reqData[0], err) + } + return []byte{agentFailure} + } + + if err == nil && rep == nil { + return []byte{agentSuccess} + } + + return ssh.Marshal(rep) +} + +func marshalKey(k *Key) []byte { + var record struct { + Blob []byte + Comment string + } + record.Blob = k.Marshal() + record.Comment = k.Comment + + return ssh.Marshal(&record) +} + +// See [PROTOCOL.agent], section 2.5.1. +const agentV1IdentitiesAnswer = 2 + +type agentV1IdentityMsg struct { + Numkeys uint32 `sshtype:"2"` +} + +type agentRemoveIdentityMsg struct { + KeyBlob []byte `sshtype:"18"` +} + +type agentLockMsg struct { + Passphrase []byte `sshtype:"22"` +} + +type agentUnlockMsg struct { + Passphrase []byte `sshtype:"23"` +} + +func (s *server) processRequest(data []byte) (interface{}, error) { + switch data[0] { + case agentRequestV1Identities: + return &agentV1IdentityMsg{0}, nil + + case agentRemoveAllV1Identities: + return nil, nil + + case agentRemoveIdentity: + var req agentRemoveIdentityMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) + + case agentRemoveAllIdentities: + return nil, s.agent.RemoveAll() + + case agentLock: + var req agentLockMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + return nil, s.agent.Lock(req.Passphrase) + + case agentUnlock: + var req agentLockMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + return nil, s.agent.Unlock(req.Passphrase) + + case agentSignRequest: + var req signRequestAgentMsg + if err := ssh.Unmarshal(data, &req); err != nil { + return nil, err + } + + var wk wireKey + if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { + return nil, err + } + + k := &Key{ + Format: wk.Format, + Blob: req.KeyBlob, + } + + sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags. + if err != nil { + return nil, err + } + return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil + + case agentRequestIdentities: + keys, err := s.agent.List() + if err != nil { + return nil, err + } + + rep := identitiesAnswerAgentMsg{ + NumKeys: uint32(len(keys)), + } + for _, k := range keys { + rep.Keys = append(rep.Keys, marshalKey(k)...) + } + return rep, nil + + case agentAddIdConstrained, agentAddIdentity: + return nil, s.insertIdentity(data) + } + + return nil, fmt.Errorf("unknown opcode %d", data[0]) +} + +func parseRSAKey(req []byte) (*AddedKey, error) { + var k rsaKeyMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + if k.E.BitLen() > 30 { + return nil, errors.New("agent: RSA public exponent too large") + } + priv := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + E: int(k.E.Int64()), + N: k.N, + }, + D: k.D, + Primes: []*big.Int{k.P, k.Q}, + } + priv.Precompute() + + return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil +} + +func parseEd25519Key(req []byte) (*AddedKey, error) { + var k ed25519KeyMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + priv := ed25519.PrivateKey(k.Priv) + return &AddedKey{PrivateKey: &priv, Comment: k.Comments}, nil +} + +func parseDSAKey(req []byte) (*AddedKey, error) { + var k dsaKeyMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + priv := &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Y, + }, + X: k.X, + } + + return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil +} + +func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) { + priv = &ecdsa.PrivateKey{ + D: privScalar, + } + + switch curveName { + case "nistp256": + priv.Curve = elliptic.P256() + case "nistp384": + priv.Curve = elliptic.P384() + case "nistp521": + priv.Curve = elliptic.P521() + default: + return nil, fmt.Errorf("agent: unknown curve %q", curveName) + } + + priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes) + if priv.X == nil || priv.Y == nil { + return nil, errors.New("agent: point not on curve") + } + + return priv, nil +} + +func parseEd25519Cert(req []byte) (*AddedKey, error) { + var k ed25519CertMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + pubKey, err := ssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + priv := ed25519.PrivateKey(k.Priv) + cert, ok := pubKey.(*ssh.Certificate) + if !ok { + return nil, errors.New("agent: bad ED25519 certificate") + } + return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil +} + +func parseECDSAKey(req []byte) (*AddedKey, error) { + var k ecdsaKeyMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + + priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D) + if err != nil { + return nil, err + } + + return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil +} + +func parseRSACert(req []byte) (*AddedKey, error) { + var k rsaCertMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + + pubKey, err := ssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + + cert, ok := pubKey.(*ssh.Certificate) + if !ok { + return nil, errors.New("agent: bad RSA certificate") + } + + // An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go + var rsaPub struct { + Name string + E *big.Int + N *big.Int + } + if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil { + return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) + } + + if rsaPub.E.BitLen() > 30 { + return nil, errors.New("agent: RSA public exponent too large") + } + + priv := rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + E: int(rsaPub.E.Int64()), + N: rsaPub.N, + }, + D: k.D, + Primes: []*big.Int{k.Q, k.P}, + } + priv.Precompute() + + return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil +} + +func parseDSACert(req []byte) (*AddedKey, error) { + var k dsaCertMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + pubKey, err := ssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + cert, ok := pubKey.(*ssh.Certificate) + if !ok { + return nil, errors.New("agent: bad DSA certificate") + } + + // A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go + var w struct { + Name string + P, Q, G, Y *big.Int + } + if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil { + return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err) + } + + priv := &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + }, + Y: w.Y, + }, + X: k.X, + } + + return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil +} + +func parseECDSACert(req []byte) (*AddedKey, error) { + var k ecdsaCertMsg + if err := ssh.Unmarshal(req, &k); err != nil { + return nil, err + } + + pubKey, err := ssh.ParsePublicKey(k.CertBytes) + if err != nil { + return nil, err + } + cert, ok := pubKey.(*ssh.Certificate) + if !ok { + return nil, errors.New("agent: bad ECDSA certificate") + } + + // An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go + var ecdsaPub struct { + Name string + ID string + Key []byte + } + if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil { + return nil, err + } + + priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D) + if err != nil { + return nil, err + } + + return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil +} + +func (s *server) insertIdentity(req []byte) error { + var record struct { + Type string `sshtype:"17|25"` + Rest []byte `ssh:"rest"` + } + + if err := ssh.Unmarshal(req, &record); err != nil { + return err + } + + var addedKey *AddedKey + var err error + + switch record.Type { + case ssh.KeyAlgoRSA: + addedKey, err = parseRSAKey(req) + case ssh.KeyAlgoDSA: + addedKey, err = parseDSAKey(req) + case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521: + addedKey, err = parseECDSAKey(req) + case ssh.KeyAlgoED25519: + addedKey, err = parseEd25519Key(req) + case ssh.CertAlgoRSAv01: + addedKey, err = parseRSACert(req) + case ssh.CertAlgoDSAv01: + addedKey, err = parseDSACert(req) + case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01: + addedKey, err = parseECDSACert(req) + case ssh.CertAlgoED25519v01: + addedKey, err = parseEd25519Cert(req) + default: + return fmt.Errorf("agent: not implemented: %q", record.Type) + } + + if err != nil { + return err + } + return s.agent.Add(*addedKey) +} + +// ServeAgent serves the agent protocol on the given connection. It +// returns when an I/O error occurs. +func ServeAgent(agent Agent, c io.ReadWriter) error { + s := &server{agent} + + var length [4]byte + for { + if _, err := io.ReadFull(c, length[:]); err != nil { + return err + } + l := binary.BigEndian.Uint32(length[:]) + if l > maxAgentResponseBytes { + // We also cap requests. + return fmt.Errorf("agent: request too large: %d", l) + } + + req := make([]byte, l) + if _, err := io.ReadFull(c, req); err != nil { + return err + } + + repData := s.processRequestBytes(req) + if len(repData) > maxAgentResponseBytes { + return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) + } + + binary.BigEndian.PutUint32(length[:], uint32(len(repData))) + if _, err := c.Write(length[:]); err != nil { + return err + } + if _, err := c.Write(repData); err != nil { + return err + } + } +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/server_test.go b/vendor/golang.org/x/crypto/ssh/agent/server_test.go new file mode 100644 index 0000000..ec9cdee --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/server_test.go @@ -0,0 +1,207 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package agent + +import ( + "crypto" + "crypto/rand" + "fmt" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestServer(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + client := NewClient(c1) + + go ServeAgent(NewKeyring(), c2) + + testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0) +} + +func TestLockServer(t *testing.T) { + testLockAgent(NewKeyring(), t) +} + +func TestSetupForwardAgent(t *testing.T) { + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + defer a.Close() + defer b.Close() + + _, socket, cleanup := startAgent(t) + defer cleanup() + + serverConf := ssh.ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + incoming := make(chan *ssh.ServerConn, 1) + go func() { + conn, _, _, err := ssh.NewServerConn(a, &serverConf) + if err != nil { + t.Fatalf("Server: %v", err) + } + incoming <- conn + }() + + conf := ssh.ClientConfig{} + conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + client := ssh.NewClient(conn, chans, reqs) + + if err := ForwardToRemote(client, socket); err != nil { + t.Fatalf("SetupForwardAgent: %v", err) + } + + server := <-incoming + ch, reqs, err := server.OpenChannel(channelType, nil) + if err != nil { + t.Fatalf("OpenChannel(%q): %v", channelType, err) + } + go ssh.DiscardRequests(reqs) + + agentClient := NewClient(ch) + testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0) + conn.Close() +} + +func TestV1ProtocolMessages(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + c := NewClient(c1) + + go ServeAgent(NewKeyring(), c2) + + testV1ProtocolMessages(t, c.(*client)) +} + +func testV1ProtocolMessages(t *testing.T, c *client) { + reply, err := c.call([]byte{agentRequestV1Identities}) + if err != nil { + t.Fatalf("v1 request all failed: %v", err) + } + if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 { + t.Fatalf("invalid request all response: %#v", reply) + } + + reply, err = c.call([]byte{agentRemoveAllV1Identities}) + if err != nil { + t.Fatalf("v1 remove all failed: %v", err) + } + if _, ok := reply.(*successAgentMsg); !ok { + t.Fatalf("invalid remove all response: %#v", reply) + } +} + +func verifyKey(sshAgent Agent) error { + keys, err := sshAgent.List() + if err != nil { + return fmt.Errorf("listing keys: %v", err) + } + + if len(keys) != 1 { + return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys)) + } + + buf := make([]byte, 128) + if _, err := rand.Read(buf); err != nil { + return fmt.Errorf("rand: %v", err) + } + + sig, err := sshAgent.Sign(keys[0], buf) + if err != nil { + return fmt.Errorf("sign: %v", err) + } + + if err := keys[0].Verify(buf, sig); err != nil { + return fmt.Errorf("verify: %v", err) + } + return nil +} + +func addKeyToAgent(key crypto.PrivateKey) error { + sshAgent := NewKeyring() + if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil { + return fmt.Errorf("add: %v", err) + } + return verifyKey(sshAgent) +} + +func TestKeyTypes(t *testing.T) { + for k, v := range testPrivateKeys { + if err := addKeyToAgent(v); err != nil { + t.Errorf("error adding key type %s, %v", k, err) + } + if err := addCertToAgentSock(v, nil); err != nil { + t.Errorf("error adding key type %s, %v", k, err) + } + } +} + +func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error { + a, b, err := netPipe() + if err != nil { + return err + } + agentServer := NewKeyring() + go ServeAgent(agentServer, a) + + agentClient := NewClient(b) + if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil { + return fmt.Errorf("add: %v", err) + } + return verifyKey(agentClient) +} + +func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error { + sshAgent := NewKeyring() + if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil { + return fmt.Errorf("add: %v", err) + } + return verifyKey(sshAgent) +} + +func TestCertTypes(t *testing.T) { + for keyType, key := range testPublicKeys { + cert := &ssh.Certificate{ + ValidPrincipals: []string{"gopher1"}, + ValidAfter: 0, + ValidBefore: ssh.CertTimeInfinity, + Key: key, + Serial: 1, + CertType: ssh.UserCert, + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil { + t.Fatalf("signcert: %v", err) + } + if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil { + t.Fatalf("%v", err) + } + if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil { + t.Fatalf("%v", err) + } + } +} diff --git a/vendor/golang.org/x/crypto/ssh/agent/testdata_test.go b/vendor/golang.org/x/crypto/ssh/agent/testdata_test.go new file mode 100644 index 0000000..cc42a87 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/agent/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package agent + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]ssh.Signer + testPublicKeys map[string]ssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]ssh.Signer, n) + testPublicKeys = make(map[string]ssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &ssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/benchmark_test.go b/vendor/golang.org/x/crypto/ssh/benchmark_test.go new file mode 100644 index 0000000..d9f7eb9 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/benchmark_test.go @@ -0,0 +1,122 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "io" + "net" + "testing" +) + +type server struct { + *ServerConn + chans <-chan NewChannel +} + +func newServer(c net.Conn, conf *ServerConfig) (*server, error) { + sconn, chans, reqs, err := NewServerConn(c, conf) + if err != nil { + return nil, err + } + go DiscardRequests(reqs) + return &server{sconn, chans}, nil +} + +func (s *server) Accept() (NewChannel, error) { + n, ok := <-s.chans + if !ok { + return nil, io.EOF + } + return n, nil +} + +func sshPipe() (Conn, *server, error) { + c1, c2, err := netPipe() + if err != nil { + return nil, nil, err + } + + clientConf := ClientConfig{ + User: "user", + } + serverConf := ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + done := make(chan *server, 1) + go func() { + server, err := newServer(c2, &serverConf) + if err != nil { + done <- nil + } + done <- server + }() + + client, _, reqs, err := NewClientConn(c1, "", &clientConf) + if err != nil { + return nil, nil, err + } + + server := <-done + if server == nil { + return nil, nil, errors.New("server handshake failed.") + } + go DiscardRequests(reqs) + + return client, server, nil +} + +func BenchmarkEndToEnd(b *testing.B) { + b.StopTimer() + + client, server, err := sshPipe() + if err != nil { + b.Fatalf("sshPipe: %v", err) + } + + defer client.Close() + defer server.Close() + + size := (1 << 20) + input := make([]byte, size) + output := make([]byte, size) + b.SetBytes(int64(size)) + done := make(chan int, 1) + + go func() { + newCh, err := server.Accept() + if err != nil { + b.Fatalf("Client: %v", err) + } + ch, incoming, err := newCh.Accept() + go DiscardRequests(incoming) + for i := 0; i < b.N; i++ { + if _, err := io.ReadFull(ch, output); err != nil { + b.Fatalf("ReadFull: %v", err) + } + } + ch.Close() + done <- 1 + }() + + ch, in, err := client.OpenChannel("speed", nil) + if err != nil { + b.Fatalf("OpenChannel: %v", err) + } + go DiscardRequests(in) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := ch.Write(input); err != nil { + b.Fatalf("WriteFull: %v", err) + } + } + ch.Close() + b.StopTimer() + + <-done +} diff --git a/vendor/golang.org/x/crypto/ssh/buffer.go b/vendor/golang.org/x/crypto/ssh/buffer.go new file mode 100644 index 0000000..6931b51 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/buffer.go @@ -0,0 +1,98 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: newCond(), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive os.EOF. +func (b *buffer) eof() error { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() + return nil +} + +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/vendor/golang.org/x/crypto/ssh/buffer_test.go b/vendor/golang.org/x/crypto/ssh/buffer_test.go new file mode 100644 index 0000000..d5781cb --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/buffer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "testing" +) + +var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") + +func TestBufferReadwrite(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + r, _ := b.Read(make([]byte, 10)) + if r != 10 { + t.Fatalf("Expected written == read == 10, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + r, _ = b.Read(make([]byte, 10)) + if r != 5 { + t.Fatalf("Expected written == read == 5, written: 5, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:10]) + r, _ = b.Read(make([]byte, 5)) + if r != 5 { + t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + b.write(alphabet[5:15]) + r, _ = b.Read(make([]byte, 10)) + r2, _ := b.Read(make([]byte, 10)) + if r != 10 || r2 != 5 || 15 != r+r2 { + t.Fatal("Expected written == read == 15") + } +} + +func TestBufferClose(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + b.eof() + _, err := b.Read(make([]byte, 5)) + if err != nil { + t.Fatal("expected read of 5 to not return EOF") + } + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err := b.Read(make([]byte, 5)) + r2, err2 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || err != nil || err2 != nil { + t.Fatal("expected reads of 5 and 5") + } + + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err = b.Read(make([]byte, 5)) + r2, err2 = b.Read(make([]byte, 10)) + r3, err3 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { + t.Fatal("expected reads of 5 and 5 and 0, with EOF") + } + + b = newBuffer() + b.write(make([]byte, 5)) + b.write(make([]byte, 10)) + b.eof() + r, err = b.Read(make([]byte, 9)) + r2, err2 = b.Read(make([]byte, 3)) + r3, err3 = b.Read(make([]byte, 3)) + r4, err4 := b.Read(make([]byte, 10)) + if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { + t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) + } + if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { + t.Fatal("Expected written == read == 15", r, r2, r3, r4) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/certs.go b/vendor/golang.org/x/crypto/ssh/certs.go new file mode 100644 index 0000000..6331c94 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/certs.go @@ -0,0 +1,503 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sort" + "time" +) + +// These constants from [PROTOCOL.certkeys] represent the algorithm names +// for certificate types supported by this package. +const ( + CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" + CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" + CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" + CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" + CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com" +) + +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. +const ( + UserCert = 1 + HostCert = 2 +) + +// Signature represents a cryptographic signature. +type Signature struct { + Format string + Blob []byte +} + +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 + +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature +} + +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} + +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) + } + return to +} + +type optionsTuple struct { + Key string + Value []byte +} + +type optionsTupleValue struct { + Value string +} + +// serialize a map of critical options or extensions +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty string value +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for key := range tups { + keys = append(keys, key) + } + sort.Strings(keys) + + var ret []byte + for _, key := range keys { + s := optionsTuple{Key: key} + if value := tups[key]; len(value) > 0 { + s.Value = Marshal(&optionsTupleValue{value}) + } + ret = append(ret, Marshal(&s)...) + } + return ret +} + +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty option value +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool + + for len(in) > 0 { + var key, val, extra []byte + var ok bool + + if key, in, ok = parseString(in); !ok { + return nil, errShortRead + } + keyStr := string(key) + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. + if haveLastKey && keyStr <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") + } + lastKey, haveLastKey = keyStr, true + // the next field is a data field, which if non-empty has a string embedded + if val, in, ok = parseString(in); !ok { + return nil, errShortRead + } + if len(val) > 0 { + val, extra, ok = parseString(val) + if !ok { + return nil, errShortRead + } + if len(extra) > 0 { + return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") + } + tups[keyStr] = string(val) + } else { + tups[keyStr] = "" + } + } + return tups, nil +} + +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } + + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err + } + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, + } + + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err + } + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err + } + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err + } + + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil +} + +type openSSHCertSigner struct { + pub *Certificate + signer Signer +} + +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return nil, errors.New("ssh: signer and cert have different public key") + } + + return &openSSHCertSigner{cert, signer}, nil +} + +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) +} + +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub +} + +const sourceAddressCriticalOption = "source-address" + +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsAuthority should return true if the key is recognized as + // an authority. This allows for certificates to be signed by other + // certificates. + IsAuthority func(auth PublicKey) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. +func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) + if !ok { + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") + } + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) + } + + return c.CheckCert(addr, cert) +} + +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err + } + + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. +func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) + } + + for opt, _ := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } + } + + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } + } + + if !c.IsAuthority(cert.SignatureKey) { + return fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + clock := c.Clock + if clock == nil { + clock = time.Now + } + + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") + } + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") + } + + return nil +} + +// SignCert sets c.SignatureKey to the authority's public key and stores a +// Signature, by authority, in the certificate. +func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err + } + c.SignatureKey = authority.PublicKey() + + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err + } + c.Signature = sig + return nil +} + +var certAlgoNames = map[string]string{ + KeyAlgoRSA: CertAlgoRSAv01, + KeyAlgoDSA: CertAlgoDSAv01, + KeyAlgoECDSA256: CertAlgoECDSA256v01, + KeyAlgoECDSA384: CertAlgoECDSA384v01, + KeyAlgoECDSA521: CertAlgoECDSA521v01, + KeyAlgoED25519: CertAlgoED25519v01, +} + +// certToPrivAlgo returns the underlying algorithm for a certificate algorithm. +// Panics if a non-certificate algorithm is passed. +func certToPrivAlgo(algo string) string { + for privAlgo, pubAlgo := range certAlgoNames { + if pubAlgo == algo { + return privAlgo + } + } + panic("unknown cert algorithm") +} + +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] +} + +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), + } + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) + } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) + + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result +} + +// Type returns the key name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + algo, ok := certAlgoNames[c.Key.Type()] + if !ok { + panic("unknown cert key type " + c.Key.Type()) + } + return algo +} + +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. +func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) +} + +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { + return + } + + out = &Signature{ + Format: string(format), + } + + if out.Blob, in, ok = parseString(in); !ok { + return + } + + return out, in, ok +} + +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { + return + } + + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { + return nil, nil, false + } + return +} diff --git a/vendor/golang.org/x/crypto/ssh/certs_test.go b/vendor/golang.org/x/crypto/ssh/certs_test.go new file mode 100644 index 0000000..c5f2e53 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/certs_test.go @@ -0,0 +1,216 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "reflect" + "testing" + "time" +) + +// Cert generated by ssh-keygen 6.0p1 Debian-4. +// % ssh-keygen -s ca-key -I test user-key +const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=` + +func TestParseCert(t *testing.T) { + authKeyBytes := []byte(exampleSSHCert) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + + if _, ok := key.(*Certificate); !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3 +// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub +// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN +// Critical Options: +// force-command /bin/sleep +// source-address 192.168.1.0/24 +// Extensions: +// permit-X11-forwarding +// permit-agent-forwarding +// permit-port-forwarding +// permit-pty +// permit-user-rc +const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` + +func TestParseCertWithOptions(t *testing.T) { + opts := map[string]string{ + "source-address": "192.168.1.0/24", + "force-command": "/bin/sleep", + } + exts := map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + } + authKeyBytes := []byte(exampleSSHCertWithOptions) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + cert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + if !reflect.DeepEqual(cert.CriticalOptions, opts) { + t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) + } + if !reflect.DeepEqual(cert.Extensions, exts) { + t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) + } + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +func TestValidateCert(t *testing.T) { + key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + validCert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + checker := CertChecker{} + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) + } + + if err := checker.CheckCert("user", validCert); err != nil { + t.Errorf("Unable to validate certificate: %v", err) + } + invalidCert := &Certificate{ + Key: testPublicKeys["rsa"], + SignatureKey: testPublicKeys["ecdsa"], + ValidBefore: CertTimeInfinity, + Signature: &Signature{}, + } + if err := checker.CheckCert("user", invalidCert); err == nil { + t.Error("Invalid cert signature passed validation") + } +} + +func TestValidateCertTime(t *testing.T) { + cert := Certificate{ + ValidPrincipals: []string{"user"}, + Key: testPublicKeys["rsa"], + ValidAfter: 50, + ValidBefore: 100, + } + + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + for ts, ok := range map[int64]bool{ + 25: false, + 50: true, + 99: true, + 100: false, + 125: false, + } { + checker := CertChecker{ + Clock: func() time.Time { return time.Unix(ts, 0) }, + } + checker.IsAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), + testPublicKeys["ecdsa"].Marshal()) + } + + if v := checker.CheckCert("user", &cert); (v == nil) != ok { + t.Errorf("Authenticate(%d): %v", ts, v) + } + } +} + +// TODO(hanwen): tests for +// +// host keys: +// * fallbacks + +func TestHostKeyCert(t *testing.T) { + cert := &Certificate{ + ValidPrincipals: []string{"hostname", "hostname.domain"}, + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: HostCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + checker := &CertChecker{ + IsAuthority: func(p PublicKey) bool { + return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) + }, + } + + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Errorf("NewCertSigner: %v", err) + } + + for _, name := range []string{"hostname", "otherhost"} { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errc := make(chan error) + + go func() { + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(certSigner) + _, _, _, err := NewServerConn(c1, &conf) + errc <- err + }() + + config := &ClientConfig{ + User: "user", + HostKeyCallback: checker.CheckHostKey, + } + _, _, _, err = NewClientConn(c2, name, config) + + succeed := name == "hostname" + if (err == nil) != succeed { + t.Fatalf("NewClientConn(%q): %v", name, err) + } + + err = <-errc + if (err == nil) != succeed { + t.Fatalf("NewServerConn(%q): %v", name, err) + } + } +} diff --git a/vendor/golang.org/x/crypto/ssh/channel.go b/vendor/golang.org/x/crypto/ssh/channel.go new file mode 100644 index 0000000..6d709b5 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/channel.go @@ -0,0 +1,633 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "sync" +) + +const ( + minPacketLength = 9 + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) + + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + + // ExtraData returns the arbitrary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + // If the channel is closed before a reply is returned, io.EOF + // is returned. + SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string + WantReply bool + Payload []byte + + ch *channel + mux *mux +} + +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason uint32 + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +// String converts the rejection reason to human readable form. +func (r RejectionReason) String() string { + switch r { + case Prohibited: + return "administratively prohibited" + case ConnectionFailed: + return "connect failed" + case UnknownChannelType: + return "unknown channel type" + case ResourceShortage: + return "resource shortage" + } + return fmt.Sprintf("unknown reason %d", int(r)) +} + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct { + // R/O after creation + chanType string + extraData []byte + localId, remoteId uint32 + + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 + + mux *mux + + // decided is set to true if an accept or reject message has been sent + // (for outbound channels) or received (for inbound channels). + decided bool + + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection + + // Pending internal channel messages. + msg chan interface{} + + // Since requests have no ID, there can be only one request + // with WantReply=true outstanding. This lock is held by a + // goroutine that has such an outgoing request pending. + sentRequestMu sync.Mutex + + incomingRequests chan *Request + + sentEOF bool + + // thread-safe data + remoteWin window + pending *buffer + extPending *buffer + + // windowMu protects myWindow, the flow-control window. + windowMu sync.Mutex + myWindow uint32 + + // writeMu serializes calls to mux.conn.writePacket() and + // protects sentClose and packetPool. This mutex must be + // different from windowMu, as writePacket can block if there + // is a key exchange pending. + writeMu sync.Mutex + sentClose bool + + // packetPool has a buffer for each extended channel ID to + // save allocations during writes. + packetPool map[uint32][]byte +} + +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (c *channel) writePacket(packet []byte) error { + c.writeMu.Lock() + if c.sentClose { + c.writeMu.Unlock() + return io.EOF + } + c.sentClose = (packet[0] == msgChannelClose) + err := c.mux.conn.writePacket(packet) + c.writeMu.Unlock() + return err +} + +func (c *channel) sendMessage(msg interface{}) error { + if debugMux { + log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) + } + + p := Marshal(msg) + binary.BigEndian.PutUint32(p[1:], c.remoteId) + return c.writePacket(p) +} + +// WriteExtended writes data to a specific extended stream. These streams are +// used, for example, for stderr. +func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { + if c.sentEOF { + return 0, io.EOF + } + // 1 byte message type, 4 bytes remoteId, 4 bytes data length + opCode := byte(msgChannelData) + headerLength := uint32(9) + if extendedCode > 0 { + headerLength += 4 + opCode = msgChannelExtendedData + } + + c.writeMu.Lock() + packet := c.packetPool[extendedCode] + // We don't remove the buffer from packetPool, so + // WriteExtended calls from different goroutines will be + // flagged as errors by the race detector. + c.writeMu.Unlock() + + for len(data) > 0 { + space := min(c.maxRemotePayload, len(data)) + if space, err = c.remoteWin.reserve(space); err != nil { + return n, err + } + if want := headerLength + space; uint32(cap(packet)) < want { + packet = make([]byte, want) + } else { + packet = packet[:want] + } + + todo := data[:space] + + packet[0] = opCode + binary.BigEndian.PutUint32(packet[1:], c.remoteId) + if extendedCode > 0 { + binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) + } + binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) + copy(packet[headerLength:], todo) + if err = c.writePacket(packet); err != nil { + return n, err + } + + n += len(todo) + data = data[len(todo):] + } + + c.writeMu.Lock() + c.packetPool[extendedCode] = packet + c.writeMu.Unlock() + + return n, err +} + +func (c *channel) handleData(packet []byte) error { + headerLen := 9 + isExtendedData := packet[0] == msgChannelExtendedData + if isExtendedData { + headerLen = 13 + } + if len(packet) < headerLen { + // malformed data packet + return parseError(packet[0]) + } + + var extended uint32 + if isExtendedData { + extended = binary.BigEndian.Uint32(packet[5:]) + } + + length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) + if length == 0 { + return nil + } + if length > c.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("ssh: incoming packet exceeds maximum payload size") + } + + data := packet[headerLen:] + if length != uint32(len(data)) { + return errors.New("ssh: wrong packet length") + } + + c.windowMu.Lock() + if c.myWindow < length { + c.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("ssh: remote side wrote too much") + } + c.myWindow -= length + c.windowMu.Unlock() + + if extended == 1 { + c.extPending.write(data) + } else if extended > 0 { + // discard other extended data. + } else { + c.pending.write(data) + } + return nil +} + +func (c *channel) adjustWindow(n uint32) error { + c.windowMu.Lock() + // Since myWindow is managed on our side, and can never exceed + // the initial window setting, we don't worry about overflow. + c.myWindow += uint32(n) + c.windowMu.Unlock() + return c.sendMessage(windowAdjustMsg{ + AdditionalBytes: uint32(n), + }) +} + +func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + switch extended { + case 1: + n, err = c.extPending.Read(data) + case 0: + n, err = c.pending.Read(data) + default: + return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) + } + + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil + } + } + + return n, err +} + +func (c *channel) close() { + c.pending.eof() + c.extPending.eof() + close(c.msg) + close(c.incomingRequests) + c.writeMu.Lock() + // This is not necessary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} + +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (c *channel) responseMessageReceived() error { + if c.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") + } + if c.decided { + return errors.New("ssh: duplicate response received for channel") + } + c.decided = true + return nil +} + +func (c *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return c.handleData(packet) + case msgChannelClose: + c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) + c.mux.chanList.remove(c.localId) + c.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + c.extPending.eof() + c.pending.eof() + return nil + } + + decoded, err := decode(packet) + if err != nil { + return err + } + + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := c.responseMessageReceived(); err != nil { + return err + } + c.mux.chanList.remove(msg.PeersId) + c.msg <- msg + case *channelOpenConfirmMsg: + if err := c.responseMessageReceived(); err != nil { + return err + } + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + c.remoteId = msg.MyId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.MyWindow) + c.msg <- msg + case *windowAdjustMsg: + if !c.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: c, + } + + c.incomingRequests <- &req + default: + c.msg <- msg + } + return nil +} + +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, 16), + msg: make(chan interface{}, 16), + chanType: chanType, + extraData: extraData, + mux: m, + packetPool: make(map[uint32][]byte), + } + ch.localId = m.chanList.add(ch) + return ch +} + +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") + +type extChannel struct { + code uint32 + ch *channel +} + +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} + +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} + +func (c *channel) Accept() (Channel, <-chan *Request, error) { + if c.decided { + return nil, nil, errDecidedAlready + } + c.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersId: c.remoteId, + MyId: c.localId, + MyWindow: c.myWindow, + MaxPacketSize: c.maxIncomingPayload, + } + c.decided = true + if err := c.sendMessage(confirm); err != nil { + return nil, nil, err + } + + return c, c.incomingRequests, nil +} + +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready + } + reject := channelOpenFailureMsg{ + PeersId: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", + } + ch.decided = true + return ch.sendMessage(reject) +} + +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.ReadExtended(data, 0) +} + +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.WriteExtended(data, 0) +} + +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided + } + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersId: ch.remoteId}) +} + +func (ch *channel) Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersId: ch.remoteId}) +} + +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} +} + +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) +} + +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided + } + + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() + } + + msg := channelRequestMsg{ + PeersId: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, + } + + if err := ch.sendMessage(msg); err != nil { + return false, err + } + + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF + } + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) + } + } + + return false, nil +} + +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided + } + + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersId: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersId: ch.remoteId, + } + } + return ch.sendMessage(msg) +} + +func (ch *channel) ChannelType() string { + return ch.chanType +} + +func (ch *channel) ExtraData() []byte { + return ch.extraData +} diff --git a/vendor/golang.org/x/crypto/ssh/cipher.go b/vendor/golang.org/x/crypto/ssh/cipher.go new file mode 100644 index 0000000..34d3917 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/cipher.go @@ -0,0 +1,579 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + "io/ioutil" +) + +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) + +// noneCipher implements cipher.Stream and provides no encryption. It is used +// by the transport before the first key-exchange. +type noneCipher struct{} + +func (c noneCipher) XORKeyStream(dst, src []byte) { + copy(dst, src) +} + +func newAESCTR(key, iv []byte) (cipher.Stream, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(c, iv), nil +} + +func newRC4(key, iv []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) +} + +type streamCipherMode struct { + keySize int + ivSize int + skip int + createFunc func(key, iv []byte) (cipher.Stream, error) +} + +func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { + if len(key) < c.keySize { + panic("ssh: key length too small for cipher") + } + if len(iv) < c.ivSize { + panic("ssh: iv too small for cipher") + } + + stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) + if err != nil { + return nil, err + } + + var streamDump []byte + if c.skip > 0 { + streamDump = make([]byte, 512) + } + + for remainingToDump := c.skip; remainingToDump > 0; { + dumpThisTime := remainingToDump + if dumpThisTime > len(streamDump) { + dumpThisTime = len(streamDump) + } + stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) + remainingToDump -= dumpThisTime + } + + return stream, nil +} + +// cipherModes documents properties of supported ciphers. Ciphers not included +// are not supported and will not be negotiated, even if explicitly requested in +// ClientConfig.Crypto.Ciphers. +var cipherModes = map[string]*streamCipherMode{ + // Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms + // are defined in the order specified in the RFC. + "aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, + "aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, + "aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, + + // Ciphers from RFC4345, which introduces security-improved arcfour ciphers. + // They are defined in the order specified in the RFC. + "arcfour128": {16, 0, 1536, newRC4}, + "arcfour256": {32, 0, 1536, newRC4}, + + // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. + // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and + // RC4) has problems with weak keys, and should be used with caution." + // RFC4345 introduces improved versions of Arcfour. + "arcfour": {16, 0, 0, newRC4}, + + // AES-GCM is not a stream cipher, so it is constructed with a + // special case. If we add any more non-stream ciphers, we + // should invest a cleaner way to do this. + gcmCipherID: {16, 12, 0, nil}, + + // CBC mode is insecure and so is not included in the default config. + // (See http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf). If absolutely + // needed, it's possible to specify a custom Config to enable it. + // You should expect that an active attacker can recover plaintext if + // you do. + aes128cbcID: {16, aes.BlockSize, 0, nil}, + + // 3des-cbc is insecure and is disabled by default. + tripledescbcID: {24, des.BlockSize, 0, nil}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + + // The following members are to avoid per-packet allocations. + prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. + if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + s.mac.Write(data) + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writePacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + s.mac.Write(s.prefix[:]) + s.mac.Write(packet) + s.mac.Write(padding) + } + + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil +} + +const gcmTagSize = 16 + +func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. + padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded.") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + padding := plain[0] + if padding < 4 || padding >= 20 { + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil +} + +// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 +type cbcCipher struct { + mac hash.Hash + macSize uint32 + decrypter cipher.BlockMode + encrypter cipher.BlockMode + + // The following members are to avoid per-packet allocations. + seqNumBytes [4]byte + packetData []byte + macResult []byte + + // Amount of data we should still read to hide which + // verification error triggered. + oracleCamouflage uint32 +} + +func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + cbc := &cbcCipher{ + mac: macModes[algs.MAC].new(macKey), + decrypter: cipher.NewCBCDecrypter(c, iv), + encrypter: cipher.NewCBCEncrypter(c, iv), + packetData: make([]byte, 1024), + } + if cbc.mac != nil { + cbc.macSize = uint32(cbc.mac.Size()) + } + + return cbc, nil +} + +func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, iv, key, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func newTripleDESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, iv, key, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func maxUInt32(a, b int) uint32 { + if a > b { + return uint32(a) + } + return uint32(b) +} + +const ( + cbcMinPacketSizeMultiple = 8 + cbcMinPacketSize = 16 + cbcMinPaddingSize = 4 +) + +// cbcError represents a verification error that may leak information. +type cbcError string + +func (e cbcError) Error() string { return string(e) } + +func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { + p, err := c.readPacketLeaky(seqNum, r) + if err != nil { + if _, ok := err.(cbcError); ok { + // Verification error: read a fixed amount of + // data, to make distinguishing between + // failing MAC and failing length check more + // difficult. + io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) + } + } + return p, err +} + +func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { + blockSize := c.decrypter.BlockSize() + + // Read the header, which will include some of the subsequent data in the + // case of block ciphers - this is copied back to the payload later. + // How many bytes of payload/padding will be read with this first read. + firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) + firstBlock := c.packetData[:firstBlockLength] + if _, err := io.ReadFull(r, firstBlock); err != nil { + return nil, err + } + + c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength + + c.decrypter.CryptBlocks(firstBlock, firstBlock) + length := binary.BigEndian.Uint32(firstBlock[:4]) + if length > maxPacket { + return nil, cbcError("ssh: packet too large") + } + if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { + // The minimum size of a packet is 16 (or the cipher block size, whichever + // is larger) bytes. + return nil, cbcError("ssh: packet too small") + } + // The length of the packet (including the length field but not the MAC) must + // be a multiple of the block size or 8, whichever is larger. + if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { + return nil, cbcError("ssh: invalid packet length multiple") + } + + paddingLength := uint32(firstBlock[4]) + if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { + return nil, cbcError("ssh: invalid packet length") + } + + // Positions within the c.packetData buffer: + macStart := 4 + length + paddingStart := macStart - paddingLength + + // Entire packet size, starting before length, ending at end of mac. + entirePacketSize := macStart + c.macSize + + // Ensure c.packetData is large enough for the entire packet data. + if uint32(cap(c.packetData)) < entirePacketSize { + // Still need to upsize and copy, but this should be rare at runtime, only + // on upsizing the packetData buffer. + c.packetData = make([]byte, entirePacketSize) + copy(c.packetData, firstBlock) + } else { + c.packetData = c.packetData[:entirePacketSize] + } + + if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { + return nil, err + } else { + c.oracleCamouflage -= uint32(n) + } + + remainingCrypted := c.packetData[firstBlockLength:macStart] + c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) + + mac := c.packetData[macStart:] + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData[:macStart]) + c.macResult = c.mac.Sum(c.macResult[:0]) + if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { + return nil, cbcError("ssh: MAC failure") + } + } + + return c.packetData[prefixLen:paddingStart], nil +} + +func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) + + // Length of encrypted portion of the packet (header, payload, padding). + // Enforce minimum padding and packet size. + encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) + // Enforce block size. + encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize + + length := encLength - 4 + paddingLength := int(length) - (1 + len(packet)) + + // Overall buffer contains: header, payload, padding, mac. + // Space for the MAC is reserved in the capacity but not the slice length. + bufferSize := encLength + c.macSize + if uint32(cap(c.packetData)) < bufferSize { + c.packetData = make([]byte, encLength, bufferSize) + } else { + c.packetData = c.packetData[:encLength] + } + + p := c.packetData + + // Packet header. + binary.BigEndian.PutUint32(p, length) + p = p[4:] + p[0] = byte(paddingLength) + + // Payload. + p = p[1:] + copy(p, packet) + + // Padding. + p = p[len(packet):] + if _, err := io.ReadFull(rand, p); err != nil { + return err + } + + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData) + // The MAC is now appended into the capacity reserved for it earlier. + c.packetData = c.mac.Sum(c.packetData) + } + + c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) + + if _, err := w.Write(c.packetData); err != nil { + return err + } + + return nil +} diff --git a/vendor/golang.org/x/crypto/ssh/cipher_test.go b/vendor/golang.org/x/crypto/ssh/cipher_test.go new file mode 100644 index 0000000..eced8d8 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/cipher_test.go @@ -0,0 +1,127 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/rand" + "testing" +) + +func TestDefaultCiphersExist(t *testing.T) { + for _, cipherAlgo := range supportedCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("default cipher %q is unknown", cipherAlgo) + } + } +} + +func TestPacketCiphers(t *testing.T) { + // Still test aes128cbc cipher although it's commented out. + cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} + defer delete(cipherModes, aes128cbcID) + + for cipher := range cipherModes { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: cipher, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) + continue + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Errorf("newPacketCipher(client, %q): %v", cipher, err) + continue + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writePacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writePacket(%q): %v", cipher, err) + continue + } + + packet, err := server.readPacket(0, buf) + if err != nil { + t.Errorf("readPacket(%q): %v", cipher, err) + continue + } + + if string(packet) != want { + t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) + } + } +} + +func TestCBCOracleCounterMeasure(t *testing.T) { + cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} + defer delete(cipherModes, aes128cbcID) + + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: aes128cbcID, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writePacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writePacket: %v", err) + } + + packetSize := buf.Len() + buf.Write(make([]byte, 2*maxPacket)) + + // We corrupt each byte, but this usually will only test the + // 'packet too large' or 'MAC failure' cases. + lastRead := -1 + for i := 0; i < packetSize; i++ { + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + fresh := &bytes.Buffer{} + fresh.Write(buf.Bytes()) + fresh.Bytes()[i] ^= 0x01 + + before := fresh.Len() + _, err = server.readPacket(0, fresh) + if err == nil { + t.Errorf("corrupt byte %d: readPacket succeeded ", i) + continue + } + if _, ok := err.(cbcError); !ok { + t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) + continue + } + + after := fresh.Len() + bytesRead := before - after + if bytesRead < maxPacket { + t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) + continue + } + + if i > 0 && bytesRead != lastRead { + t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) + } + lastRead = bytesRead + } +} diff --git a/vendor/golang.org/x/crypto/ssh/client.go b/vendor/golang.org/x/crypto/ssh/client.go new file mode 100644 index 0000000..0212a20 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/client.go @@ -0,0 +1,213 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "net" + "sync" + "time" +) + +// Client implements a traditional SSH client that supports shells, +// subprocesses, port forwarding and tunneled dialing. +type Client struct { + Conn + + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel +} + +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c + } + + ch := c.channelHandlers[channelType] + if ch != nil { + return nil + } + + ch = make(chan NewChannel, 16) + c.channelHandlers[channelType] = ch + return ch +} + +// NewClient creates a Client on top of the given connection. +func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), + } + + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) + return conn +} + +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + conn := &connection{ + sshConn: sshConn{conn: c}, + } + + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. +func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) + } else { + c.clientVersion = []byte(packageVersion) + } + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + return err + } + + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + if err := c.transport.requestInitialKeyChange(); err != nil { + return err + } + + // We just did the key change, so the session ID is established. + c.sessionID = c.transport.getSessionID() + + return c.clientAuthenticate(config) +} + +// verifyHostKeySignature verifies the host key obtained in the key +// exchange. +func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) + if len(rest) > 0 || !ok { + return errors.New("ssh: signature parse error") + } + + return hostKey.Verify(result.H, sig) +} + +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err + } + return newSession(ch, in) +} + +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + +// handleChannelOpens channel open messages from the remote side. +func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() + + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) + } + } + + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) + } + c.channelHandlers = nil + c.mu.Unlock() +} + +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil +} + +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. +type ClientConfig struct { + // Config contains configuration that is shared between clients and + // servers. + Config + + // User contains the username to authenticate as. + User string + + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod + + // HostKeyCallback, if not nil, is called during the cryptographic + // handshake to validate the server's host key. A nil HostKeyCallback + // implies that all host keys are accepted. + HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. + ClientVersion string + + // HostKeyAlgorithms lists the key types that the client will + // accept from the server as host key, in order of + // preference. If empty, a reasonable default is used. Any + // string returned from PublicKey.Type method may be used, or + // any of the CertAlgoXxxx and KeyAlgoXxxx constants. + HostKeyAlgorithms []string + + // Timeout is the maximum amount of time for the TCP connection to establish. + // + // A Timeout of zero means no timeout. + Timeout time.Duration +} diff --git a/vendor/golang.org/x/crypto/ssh/client_auth.go b/vendor/golang.org/x/crypto/ssh/client_auth.go new file mode 100644 index 0000000..294af0d --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/client_auth.go @@ -0,0 +1,473 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +// clientAuthenticate authenticates with the remote server. See RFC 4252. +func (c *connection) clientAuthenticate(config *ClientConfig) error { + // initiate user auth session + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + return err + } + packet, err := c.transport.readPacket() + if err != nil { + return err + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + return err + } + + // during the authentication phase the client first attempts the "none" method + // then any untried methods suggested by the server. + tried := make(map[string]bool) + var lastMethods []string + for auth := AuthMethod(new(noneAuth)); auth != nil; { + ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) + if err != nil { + return err + } + if ok { + // success + return nil + } + tried[auth.method()] = true + if methods == nil { + methods = lastMethods + } + lastMethods = methods + + auth = nil + + findNext: + for _, a := range config.Auth { + candidateMethod := a.method() + if tried[candidateMethod] { + continue + } + for _, meth := range methods { + if meth == candidateMethod { + auth = a + break findNext + } + } + } + } + return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) +} + +func keys(m map[string]bool) []string { + s := make([]string, 0, len(m)) + + for key := range m { + s = append(s, key) + } + return s +} + +// An AuthMethod represents an instance of an RFC 4252 authentication method. +type AuthMethod interface { + // auth authenticates user over transport t. + // Returns true if authentication is successful. + // If authentication is not successful, a []string of alternative + // method names is returned. If the slice is nil, it will be ignored + // and the previous set of possible methods will be reused. + auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) + + // method returns the RFC 4252 method name. + method() string +} + +// "none" authentication, RFC 4252 section 5.2. +type noneAuth int + +func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + if err := c.writePacket(Marshal(&userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: "none", + })); err != nil { + return false, nil, err + } + + return handleAuthResponse(c) +} + +func (n *noneAuth) method() string { + return "none" +} + +// passwordCallback is an AuthMethod that fetches the password through +// a function call, e.g. by prompting the user. +type passwordCallback func() (password string, err error) + +func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + type passwordAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + Reply bool + Password string + } + + pw, err := cb() + // REVIEW NOTE: is there a need to support skipping a password attempt? + // The program may only find out that the user doesn't have a password + // when prompting. + if err != nil { + return false, nil, err + } + + if err := c.writePacket(Marshal(&passwordAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + Reply: false, + Password: pw, + })); err != nil { + return false, nil, err + } + + return handleAuthResponse(c) +} + +func (cb passwordCallback) method() string { + return "password" +} + +// Password returns an AuthMethod using the given password. +func Password(secret string) AuthMethod { + return passwordCallback(func() (string, error) { return secret, nil }) +} + +// PasswordCallback returns an AuthMethod that uses a callback for +// fetching a password. +func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) +} + +type publickeyAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + // HasSig indicates to the receiver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey + Sig []byte `ssh:"rest"` +} + +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + // Authentication is performed in two stages. The first stage sends an + // enquiry to test if each key is acceptable to the remote. The second + // stage attempts to authenticate with the valid keys obtained in the + // first stage. + + signers, err := cb() + if err != nil { + return false, nil, err + } + var validKeys []Signer + for _, signer := range signers { + if ok, err := validateKey(signer.PublicKey(), user, c); ok { + validKeys = append(validKeys, signer) + } else { + if err != nil { + return false, nil, err + } + } + } + + // methods that may continue if this auth is not successful. + var methods []string + for _, signer := range validKeys { + pub := signer.PublicKey() + + pubKey := pub.Marshal() + sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, []byte(pub.Type()), pubKey)) + if err != nil { + return false, nil, err + } + + // manually wrap the serialized signature in a string + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: pub.Type(), + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return false, nil, err + } + var success bool + success, methods, err = handleAuthResponse(c) + if err != nil { + return false, nil, err + } + if success { + return success, methods, err + } + } + return false, methods, nil +} + +// validateKey validates the key provided is acceptable to the server. +func validateKey(key PublicKey, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: "publickey", + HasSig: false, + Algoname: key.Type(), + PubKey: pubKey, + } + if err := c.writePacket(Marshal(&msg)); err != nil { + return false, err + } + + return confirmKeyAck(key, c) +} + +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + algoname := key.Type() + + for { + packet, err := c.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + // TODO(gpaul): add callback to present the banner to the user + case msgUserAuthPubKeyOk: + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, err + } + if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) +} + +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. +func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) +} + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(c packetConn) (bool, []string, error) { + for { + packet, err := c.readPacket() + if err != nil { + return false, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + // TODO: add callback to present the banner to the user + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + return false, msg.Methods, nil + case msgUserAuthSuccess: + return true, nil, nil + default: + return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. After +// successful authentication, the server may send a challenge with no +// questions, for which the user and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) + +// KeyboardInteractive returns a AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge +} + +func (cb KeyboardInteractiveChallenge) method() string { + return "keyboard-interactive" +} + +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { + type initiateMsg struct { + User string `sshtype:"50"` + Service string + Method string + Language string + Submethods string + } + + if err := c.writePacket(Marshal(&initiateMsg{ + User: user, + Service: serviceSSH, + Method: "keyboard-interactive", + })); err != nil { + return false, nil, err + } + + for { + packet, err := c.readPacket() + if err != nil { + return false, nil, err + } + + // like handleAuthResponse, but with less options. + switch packet[0] { + case msgUserAuthBanner: + // TODO: Print banners during userauth. + continue + case msgUserAuthInfoRequest: + // OK + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + return false, msg.Methods, nil + case msgUserAuthSuccess: + return true, nil, nil + default: + return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + + var msg userAuthInfoRequestMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, nil, err + } + + // Manually unpack the prompt/echo pairs. + rest := msg.Prompts + var prompts []string + var echos []bool + for i := 0; i < int(msg.NumPrompts); i++ { + prompt, r, ok := parseString(rest) + if !ok || len(r) == 0 { + return false, nil, errors.New("ssh: prompt format error") + } + prompts = append(prompts, string(prompt)) + echos = append(echos, r[0] != 0) + rest = r[1:] + } + + if len(rest) != 0 { + return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") + } + + answers, err := cb(msg.User, msg.Instruction, prompts, echos) + if err != nil { + return false, nil, err + } + + if len(answers) != len(prompts) { + return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") + } + responseLength := 1 + 4 + for _, a := range answers { + responseLength += stringLength(len(a)) + } + serialized := make([]byte, responseLength) + p := serialized + p[0] = msgUserAuthInfoResponse + p = p[1:] + p = marshalUint32(p, uint32(len(answers))) + for _, a := range answers { + p = marshalString(p, []byte(a)) + } + + if err := c.writePacket(serialized); err != nil { + return false, nil, err + } + } +} + +type retryableAuthMethod struct { + authMethod AuthMethod + maxTries int +} + +func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok bool, methods []string, err error) { + for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ { + ok, methods, err = r.authMethod.auth(session, user, c, rand) + if ok || err != nil { // either success or error terminate + return ok, methods, err + } + } + return ok, methods, err +} + +func (r *retryableAuthMethod) method() string { + return r.authMethod.method() +} + +// RetryableAuthMethod is a decorator for other auth methods enabling them to +// be retried up to maxTries before considering that AuthMethod itself failed. +// If maxTries is <= 0, will retry indefinitely +// +// This is useful for interactive clients using challenge/response type +// authentication (e.g. Keyboard-Interactive, Password, etc) where the user +// could mistype their response resulting in the server issuing a +// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4 +// [keyboard-interactive]); Without this decorator, the non-retryable +// AuthMethod would be removed from future consideration, and never tried again +// (and so the user would never be able to retry their entry). +func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod { + return &retryableAuthMethod{authMethod: auth, maxTries: maxTries} +} diff --git a/vendor/golang.org/x/crypto/ssh/client_auth_test.go b/vendor/golang.org/x/crypto/ssh/client_auth_test.go new file mode 100644 index 0000000..1409276 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/client_auth_test.go @@ -0,0 +1,472 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "os" + "strings" + "testing" +) + +type keyboardInteractive map[string]string + +func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { + var answers []string + for _, q := range questions { + answers = append(answers, cr[q]) + } + return answers, nil +} + +// reused internally by tests +var clientPassword = "tiger" + +// tryAuth runs a handshake with a given config against an SSH server +// with config serverConfig +func tryAuth(t *testing.T, config *ClientConfig) error { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + certChecker := CertChecker{ + IsAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, + } + + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + PublicKeyCallback: certChecker.Authenticate, + KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { + ans, err := challenge("user", + "instruction", + []string{"question1", "question2"}, + []bool{true, true}) + if err != nil { + return nil, err + } + ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" + if ok { + challenge("user", "motd", nil, nil) + return nil, nil + } + return nil, errors.New("keyboard-interactive failed") + }, + AuthLogCallback: func(conn ConnMetadata, method string, err error) { + t.Logf("user %q, method %q: %v", conn.User(), method, err) + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + go newServer(c1, serverConfig) + _, _, _, err = NewClientConn(c2, "", config) + return err +} + +func TestClientAuthPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password(clientPassword), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodFallback(t *testing.T) { + var passwordCalled bool + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PasswordCallback( + func() (string, error) { + passwordCalled = true + return "WRONG", nil + }), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + if passwordCalled { + t.Errorf("password auth tried before public-key auth.") + } +} + +func TestAuthMethodWrongPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password("wrong"), + PublicKeys(testSigners["rsa"]), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "answer2", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "WRONG", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") + } +} + +// the mock server will only authenticate ssh-rsa keys +func TestAuthMethodInvalidPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"]), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("dsa private key should not have authenticated with rsa public key") + } +} + +// the client should authenticate with the second key +func TestAuthMethodRSAandDSA(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"], testSigners["rsa"]), + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with rsa key: %v", err) + } +} + +func TestClientHMAC(t *testing.T) { + for _, mac := range supportedMACs { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + Config: Config{ + MACs: []string{mac}, + }, + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) + } + } +} + +// issue 4285. +func TestClientUnsupportedCipher(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + Ciphers: []string{"aes128-cbc"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil { + t.Errorf("expected no ciphers in common") + } +} + +func TestClientUnsupportedKex(t *testing.T) { + if os.Getenv("GO_BUILDER_NAME") != "" { + t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198") + } + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { + t.Errorf("got %v, expected 'common algorithm'", err) + } +} + +func TestClientLoginCert(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + } + clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) + + t.Log("should succeed") + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("corrupted signature") + cert.Signature.Blob[0]++ + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with corrupted sig") + } + + t.Log("revoked") + cert.Serial = 666 + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("revoked cert login succeeded") + } + cert.Serial = 1 + + t.Log("sign with wrong key") + cert.SignCert(rand.Reader, testSigners["dsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with non-authoritative key") + } + + t.Log("host cert") + cert.CertType = HostCert + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong type") + } + cert.CertType = UserCert + + t.Log("principal specified") + cert.ValidPrincipals = []string{"user"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + t.Log("wrong principal specified") + cert.ValidPrincipals = []string{"fred"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong principal") + } + cert.ValidPrincipals = nil + + t.Log("added critical option") + cert.CriticalOptions = map[string]string{"root-access": "yes"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with unrecognized critical option") + } + + t.Log("allowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login with source-address failed: %v", err) + } + + t.Log("disallowed source address") + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login with source-address succeeded") + } +} + +func testPermissionsPassing(withPermissions bool, t *testing.T) { + serverConfig := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "nopermissions" { + return nil, nil + } else { + return &Permissions{}, nil + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + } + if withPermissions { + clientConfig.User = "permissions" + } else { + clientConfig.User = "nopermissions" + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatal(err) + } + if p := serverConn.Permissions; (p != nil) != withPermissions { + t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) + } +} + +func TestPermissionsPassing(t *testing.T) { + testPermissionsPassing(true, t) +} + +func TestNoPermissionsPassing(t *testing.T) { + testPermissionsPassing(false, t) +} + +func TestRetryableAuth(t *testing.T) { + n := 0 + passwords := []string{"WRONG1", "WRONG2"} + + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 2), + PublicKeys(testSigners["rsa"]), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if n != 2 { + t.Fatalf("Did not try all passwords") + } +} + +func ExampleRetryableAuthMethod(t *testing.T) { + user := "testuser" + NumberOfPrompts := 3 + + // Normally this would be a callback that prompts the user to answer the + // provided questions + Cb := func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return []string{"answer1", "answer2"}, nil + } + + config := &ClientConfig{ + User: user, + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), + }, + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +// Test if username is received on server side when NoClientAuth is used +func TestClientAuthNone(t *testing.T) { + user := "testuser" + serverConfig := &ServerConfig{ + NoClientAuth: true, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + User: user, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatalf("newServer: %v", err) + } + if serverConn.User() != user { + t.Fatalf("server: got %q, want %q", serverConn.User(), user) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/client_test.go b/vendor/golang.org/x/crypto/ssh/client_test.go new file mode 100644 index 0000000..1fe790c --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/client_test.go @@ -0,0 +1,39 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "net" + "testing" +) + +func testClientVersion(t *testing.T, config *ClientConfig, expected string) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + receivedVersion := make(chan string, 1) + go func() { + version, err := readVersion(serverConn) + if err != nil { + receivedVersion <- "" + } else { + receivedVersion <- string(version) + } + serverConn.Close() + }() + NewClientConn(clientConn, "", config) + actual := <-receivedVersion + if actual != expected { + t.Fatalf("got %s; want %s", actual, expected) + } +} + +func TestCustomClientVersion(t *testing.T) { + version := "Test-Client-Version-0.0" + testClientVersion(t, &ClientConfig{ClientVersion: version}, version) +} + +func TestDefaultClientVersion(t *testing.T) { + testClientVersion(t, &ClientConfig{}, packageVersion) +} diff --git a/vendor/golang.org/x/crypto/ssh/common.go b/vendor/golang.org/x/crypto/ssh/common.go new file mode 100644 index 0000000..2c72ab5 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/common.go @@ -0,0 +1,356 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/rand" + "fmt" + "io" + "sync" + + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +// These are string constants in the SSH protocol. +const ( + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// supportedCiphers specifies the supported ciphers in preference order. +var supportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", + "arcfour256", "arcfour128", +} + +// supportedKexAlgos specifies the supported key-exchange algorithms in +// preference order. +var supportedKexAlgos = []string{ + kexAlgoCurve25519SHA256, + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA1, kexAlgoDH1SHA1, +} + +// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. +var supportedHostKeyAlgos = []string{ + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01, + + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSA, KeyAlgoDSA, + + KeyAlgoED25519, +} + +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", +} + +var supportedCompressions = []string{compressionNone} + +// hashFuncs keeps the mapping of supported algorithms to their respective +// hashes needed for signature verification. +var hashFuncs = map[string]crypto.Hash{ + KeyAlgoRSA: crypto.SHA1, + KeyAlgoDSA: crypto.SHA1, + KeyAlgoECDSA256: crypto.SHA256, + KeyAlgoECDSA384: crypto.SHA384, + KeyAlgoECDSA521: crypto.SHA512, + CertAlgoRSAv01: crypto.SHA1, + CertAlgoDSAv01: crypto.SHA1, + CertAlgoECDSA256v01: crypto.SHA256, + CertAlgoECDSA384v01: crypto.SHA384, + CertAlgoECDSA521v01: crypto.SHA512, +} + +// unexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) +} + +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) +} + +func findCommon(what string, client []string, server []string) (common string, err error) { + for _, c := range client { + for _, s := range server { + if c == s { + return c, nil + } + } + } + return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) +} + +type directionAlgorithms struct { + Cipher string + MAC string + Compression string +} + +type algorithms struct { + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms +} + +func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { + result := &algorithms{} + + result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if err != nil { + return + } + + result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if err != nil { + return + } + + result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if err != nil { + return + } + + result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if err != nil { + return + } + + result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if err != nil { + return + } + + result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if err != nil { + return + } + + result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if err != nil { + return + } + + result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + if err != nil { + return + } + + return result, nil +} + +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, 1 gigabyte is used. + RekeyThreshold uint64 + + // The allowed key exchanges algorithms. If unspecified then a + // default set of algorithms is used. + KeyExchanges []string + + // The allowed cipher algorithms. If unspecified then a sensible + // default is used. + Ciphers []string + + // The allowed MAC algorithms. If unspecified then a sensible default + // is used. + MACs []string +} + +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } + if c.Ciphers == nil { + c.Ciphers = supportedCiphers + } + var ciphers []string + for _, c := range c.Ciphers { + if cipherModes[c] != nil { + // reject the cipher if we have no cipherModes definition + ciphers = append(ciphers, c) + } + } + c.Ciphers = ciphers + + if c.KeyExchanges == nil { + c.KeyExchanges = supportedKexAlgos + } + + if c.MACs == nil { + c.MACs = supportedMACs + } + + if c.RekeyThreshold == 0 { + // RFC 4253, section 9 suggests rekeying after 1G. + c.RekeyThreshold = 1 << 30 + } + if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold + } +} + +// buildDataSignedForAuth returns the data that is signed in order to prove +// possession of a private key. See RFC 4252, section 7. +func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo []byte + PubKey []byte + }{ + sessionId, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, + } + return Marshal(data) +} + +func appendU16(buf []byte, n uint16) []byte { + return append(buf, byte(n>>8), byte(n)) +} + +func appendU32(buf []byte, n uint32) []byte { + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendInt(buf []byte, n int) []byte { + return appendU32(buf, uint32(n)) +} + +func appendString(buf []byte, s string) []byte { + buf = appendU32(buf, uint32(len(s))) + buf = append(buf, s...) + return buf +} + +func appendBool(buf []byte, b bool) []byte { + if b { + return append(buf, 1) + } + return append(buf, 0) +} + +// newCond is a helper to hide the fact that there is no usable zero +// value for sync.Cond. +func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} diff --git a/vendor/golang.org/x/crypto/ssh/connection.go b/vendor/golang.org/x/crypto/ssh/connection.go new file mode 100644 index 0000000..e786f2f --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/connection.go @@ -0,0 +1,143 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + User() string + + // SessionID returns the sesson hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the server's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. + OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshconn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} diff --git a/vendor/golang.org/x/crypto/ssh/doc.go b/vendor/golang.org/x/crypto/ssh/doc.go new file mode 100644 index 0000000..d6be894 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/doc.go @@ -0,0 +1,18 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package ssh implements an SSH client and server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +References: + [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD + [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 +*/ +package ssh // import "golang.org/x/crypto/ssh" diff --git a/vendor/golang.org/x/crypto/ssh/example_test.go b/vendor/golang.org/x/crypto/ssh/example_test.go new file mode 100644 index 0000000..4d2eabd --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/example_test.go @@ -0,0 +1,262 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh_test + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +func ExampleNewServerConn() { + // Public key authentication is done by comparing + // the public key of a received connection + // with the entries in the authorized_keys file. + authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys") + if err != nil { + log.Fatalf("Failed to load authorized_keys, err: %v", err) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + log.Fatal(err) + } + + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + // Remove to disable password auth. + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + + // Remove to disable public key auth. + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if authorizedKeysMap[string(pubKey.Marshal())] { + return nil, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + + privateBytes, err := ioutil.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake: ", err) + } + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + go func(in <-chan *ssh.Request) { + for req := range in { + req.Reply(req.Type == "shell", nil) + } + }(requests) + + term := terminal.NewTerminal(channel, "> ") + + go func() { + defer channel.Close() + for { + line, err := term.ReadLine() + if err != nil { + break + } + fmt.Println(line) + } + }() + } +} + +func ExampleDial() { + // An SSH client is represented with a ClientConn. + // + // To authenticate with the remote server you must pass at least one + // implementation of AuthMethod via the Auth field in ClientConfig. + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("yourpassword"), + }, + } + client, err := ssh.Dial("tcp", "yourserver.com:22", config) + if err != nil { + log.Fatal("Failed to dial: ", err) + } + + // Each ClientConn can support multiple interactive sessions, + // represented by a Session. + session, err := client.NewSession() + if err != nil { + log.Fatal("Failed to create session: ", err) + } + defer session.Close() + + // Once a Session is created, you can execute a single command on + // the remote side using the Run method. + var b bytes.Buffer + session.Stdout = &b + if err := session.Run("/usr/bin/whoami"); err != nil { + log.Fatal("Failed to run: " + err.Error()) + } + fmt.Println(b.String()) +} + +func ExamplePublicKeys() { + // A public key may be used to authenticate against the remote + // server by using an unencrypted PEM-encoded private key file. + // + // If you have an encrypted private key, the crypto/x509 package + // can be used to decrypt it. + key, err := ioutil.ReadFile("/home/user/.ssh/id_rsa") + if err != nil { + log.Fatalf("unable to read private key: %v", err) + } + + // Create the Signer for this private key. + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + log.Fatalf("unable to parse private key: %v", err) + } + + config := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + // Use the PublicKeys method for remote authentication. + ssh.PublicKeys(signer), + }, + } + + // Connect to the remote server and perform the SSH handshake. + client, err := ssh.Dial("tcp", "host.com:22", config) + if err != nil { + log.Fatalf("unable to connect: %v", err) + } + defer client.Close() +} + +func ExampleClient_Listen() { + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + } + // Dial your ssh server. + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + + // Request the remote side to open port 8080 on all interfaces. + l, err := conn.Listen("tcp", "0.0.0.0:8080") + if err != nil { + log.Fatal("unable to register tcp forward: ", err) + } + defer l.Close() + + // Serve HTTP with your SSH server acting as a reverse proxy. + http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + fmt.Fprintf(resp, "Hello world!\n") + })) +} + +func ExampleSession_RequestPty() { + // Create client config + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + } + // Connect to ssh server + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + // Create a session + session, err := conn.NewSession() + if err != nil { + log.Fatal("unable to create session: ", err) + } + defer session.Close() + // Set up terminal modes + modes := ssh.TerminalModes{ + ssh.ECHO: 0, // disable echoing + ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud + ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + } + // Request pseudo terminal + if err := session.RequestPty("xterm", 40, 80, modes); err != nil { + log.Fatal("request for pseudo terminal failed: ", err) + } + // Start remote shell + if err := session.Shell(); err != nil { + log.Fatal("failed to start shell: ", err) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/handshake.go b/vendor/golang.org/x/crypto/ssh/handshake.go new file mode 100644 index 0000000..37d42e4 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/handshake.go @@ -0,0 +1,460 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + // hostKeys is non-empty if we are the server. In that case, + // it contains all host keys that can be used to sign the + // connection. + hostKeys []Signer + + // hostKeyAlgorithms is non-empty if we are the client. In that case, + // we accept these key types from the server as host key. + hostKeyAlgorithms []string + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + // data for host key checking + hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + dialAddress string + remoteAddr net.Addr + + readSinceKex uint64 + + // Protects the writing side of the connection + mu sync.Mutex + cond *sync.Cond + sentInitPacket []byte + sentInitMsg *kexInitMsg + writtenSinceKex uint64 + writeError error + + // The session ID or nil if first kex did not complete yet. + sessionID []byte +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, 16), + config: config, + } + t.cond = sync.NewCond(&t.mu) + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + if config.HostKeyAlgorithms != nil { + t.hostKeyAlgorithms = config.HostKeyAlgorithms + } else { + t.hostKeyAlgorithms = supportedHostKeyAlgos + } + go t.readLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + go t.readLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.sessionID +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + for { + p, err := t.readOnePacket() + if err != nil { + t.readError = err + close(t.incoming) + break + } + if p[0] == msgIgnore || p[0] == msgDebug { + continue + } + t.incoming <- p + } + + // If we can't read, declare the writing part dead too. + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError == nil { + t.writeError = t.readError + } + t.cond.Broadcast() +} + +func (t *handshakeTransport) readOnePacket() ([]byte, error) { + if t.readSinceKex > t.config.RekeyThreshold { + if err := t.requestKeyChange(); err != nil { + return nil, err + } + } + + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + t.readSinceKex += uint64(len(p)) + if debugHandshake { + if p[0] == msgChannelData || p[0] == msgChannelExtendedData { + log.Printf("%s got data (packet %d bytes)", t.id(), len(p)) + } else { + msg, err := decode(p) + log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) + } + } + if p[0] != msgKexInit { + return p, nil + } + + t.mu.Lock() + + firstKex := t.sessionID == nil + + err = t.enterKeyExchangeLocked(p) + if err != nil { + // drop connection + t.conn.Close() + t.writeError = err + } + + if debugHandshake { + log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) + } + + // Unblock writers. + t.sentInitMsg = nil + t.sentInitPacket = nil + t.cond.Broadcast() + t.writtenSinceKex = 0 + t.mu.Unlock() + + if err != nil { + return nil, err + } + + t.readSinceKex = 0 + + // By default, a key exchange is hidden from higher layers by + // translating it into msgIgnore. + successPacket := []byte{msgIgnore} + if firstKex { + // sendKexInit() for the first kex waits for + // msgNewKeys so the authentication process is + // guaranteed to happen over an encrypted transport. + successPacket = []byte{msgNewKeys} + } + + return successPacket, nil +} + +// keyChangeCategory describes whether a key exchange is the first on a +// connection, or a subsequent one. +type keyChangeCategory bool + +const ( + firstKeyExchange keyChangeCategory = true + subsequentKeyExchange keyChangeCategory = false +) + +// sendKexInit sends a key change message, and returns the message +// that was sent. After initiating the key change, all writes will be +// blocked until the change is done, and a failed key change will +// close the underlying transport. This function is safe for +// concurrent use by multiple goroutines. +func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error { + var err error + + t.mu.Lock() + // If this is the initial key change, but we already have a sessionID, + // then do nothing because the key exchange has already completed + // asynchronously. + if !isFirst || t.sessionID == nil { + _, _, err = t.sendKexInitLocked(isFirst) + } + t.mu.Unlock() + if err != nil { + return err + } + if isFirst { + if packet, err := t.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + } + return nil +} + +func (t *handshakeTransport) requestInitialKeyChange() error { + return t.sendKexInit(firstKeyExchange) +} + +func (t *handshakeTransport) requestKeyChange() error { + return t.sendKexInit(subsequentKeyExchange) +} + +// sendKexInitLocked sends a key change message. t.mu must be locked +// while this happens. +func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + if t.sentInitMsg != nil { + return t.sentInitMsg, t.sentInitPacket, nil + } + + msg := &kexInitMsg{ + KexAlgos: t.config.KeyExchanges, + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + if len(t.hostKeys) > 0 { + for _, k := range t.hostKeys { + msg.ServerHostKeyAlgos = append( + msg.ServerHostKeyAlgos, k.PublicKey().Type()) + } + } else { + msg.ServerHostKeyAlgos = t.hostKeyAlgorithms + } + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.conn.writePacket(packetCopy); err != nil { + return nil, nil, err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + return msg, packet, nil +} + +func (t *handshakeTransport) writePacket(p []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.writtenSinceKex > t.config.RekeyThreshold { + t.sendKexInitLocked(subsequentKeyExchange) + } + for t.sentInitMsg != nil && t.writeError == nil { + t.cond.Wait() + } + if t.writeError != nil { + return t.writeError + } + t.writtenSinceKex += uint64(len(p)) + + switch p[0] { + case msgKexInit: + return errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + return errors.New("ssh: only handshakeTransport can send newKeys") + default: + return t.conn.writePacket(p) + } +} + +func (t *handshakeTransport) Close() error { + return t.conn.Close() +} + +// enterKeyExchange runs the key exchange. t.mu must be held while running this. +func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange) + if err != nil { + return err + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: myInitPacket, + } + + clientInit := otherInit + serverInit := myInit + if len(t.hostKeys) == 0 { + clientInit = myInit + serverInit = otherInit + + magics.clientKexInit = myInitPacket + magics.serverKexInit = otherInitPacket + } + + algs, err := findAgreedAlgorithms(clientInit, serverInit) + if err != nil { + return err + } + + // We don't send FirstKexFollows, but we handle receiving it. + // + // RFC 4253 section 7 defines the kex and the agreement method for + // first_kex_packet_follows. It states that the guessed packet + // should be ignored if the "kex algorithm and/or the host + // key algorithm is guessed wrong (server and client have + // different preferred algorithm), or if any of the other + // algorithms cannot be agreed upon". The other algorithms have + // already been checked above so the kex algorithm and host key + // algorithm are checked here. + if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[algs.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, algs, &magics) + } else { + result, err = t.client(kex, algs, &magics) + } + + if err != nil { + return err + } + + if t.sessionID == nil { + t.sessionID = result.H + } + result.SessionID = t.sessionID + + t.conn.prepareKeyChange(algs, result) + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + var hostKey Signer + for _, k := range t.hostKeys { + if algs.hostKey == k.PublicKey().Type() { + hostKey = k + } + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, result); err != nil { + return nil, err + } + + if t.hostKeyCallback != nil { + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + } + + return result, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/handshake_test.go b/vendor/golang.org/x/crypto/ssh/handshake_test.go new file mode 100644 index 0000000..da53d3a --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/handshake_test.go @@ -0,0 +1,486 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +type testChecker struct { + calls []string +} + +func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + if dialAddr == "bad" { + return fmt.Errorf("dialAddr is bad") + } + + if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { + return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) + } + + t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) + + return nil +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { + a, b, err := netPipe() + if err != nil { + return nil, nil, err + } + + trC := newTransport(a, rand.Reader, true) + trS := newTransport(b, rand.Reader, false) + clientConf.SetDefaults() + + v := []byte("version") + client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server = newServerTransport(trS, v, v, serverConf) + + return client, server, nil +} + +func TestHandshakeBasic(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + go func() { + // Client writes a bunch of stuff, and does a key + // change in the middle. This should not confuse the + // handshake in progress + for i := 0; i < 10; i++ { + p := []byte{msgRequestSuccess, byte(i)} + if err := trC.writePacket(p); err != nil { + t.Fatalf("sendPacket: %v", err) + } + if i == 5 { + // halfway through, we request a key change. + err := trC.sendKexInit(subsequentKeyExchange) + if err != nil { + t.Fatalf("sendKexInit: %v", err) + } + } + } + trC.Close() + }() + + // Server checks that client messages come in cleanly + i := 0 + for { + p, err := trS.readPacket() + if err != nil { + break + } + if p[0] == msgNewKeys { + continue + } + want := []byte{msgRequestSuccess, byte(i)} + if bytes.Compare(p, want) != 0 { + t.Errorf("message %d: got %q, want %q", i, p, want) + } + i++ + } + if i != 10 { + t.Errorf("received %d messages, want 10.", i) + } + + // If all went well, we registered exactly 1 key change. + if len(checker.calls) != 1 { + t.Fatalf("got %d host key checks, want 1", len(checker.calls)) + } + + pub := testSigners["ecdsa"].PublicKey() + want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) + if want != checker.calls[0] { + t.Errorf("got %q want %q for host key check", checker.calls[0], want) + } +} + +func TestHandshakeError(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + // send a packet + packet := []byte{msgRequestSuccess, 42} + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + + // Now request a key change. + err = trC.sendKexInit(subsequentKeyExchange) + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + // the key change will fail, and afterwards we can't write. + if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { + t.Errorf("writePacket after botched rekey succeeded.") + } + + readback, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + if bytes.Compare(readback, packet) != 0 { + t.Errorf("got %q want %q", readback, packet) + } + readback, err = trS.readPacket() + if err == nil { + t.Errorf("got a message %q after failed key change", readback) + } +} + +func TestForceFirstKex(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) + + // We setup the initial key exchange, but the remote side + // tries to send serviceRequestMsg in cleartext, which is + // disallowed. + + err = trS.sendKexInit(firstKeyExchange) + if err == nil { + t.Errorf("server first kex init should reject unexpected packet") + } +} + +func TestHandshakeTwice(t *testing.T) { + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + // Both sides should ask for the first key exchange first. + err = trS.sendKexInit(firstKeyExchange) + if err != nil { + t.Errorf("server sendKexInit: %v", err) + } + + err = trC.sendKexInit(firstKeyExchange) + if err != nil { + t.Errorf("client sendKexInit: %v", err) + } + + sent := 0 + // send a packet + packet := make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + sent++ + + // Send another packet. Use a fresh one, since writePacket destroys. + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + sent++ + + // 2nd key change. + err = trC.sendKexInit(subsequentKeyExchange) + if err != nil { + t.Errorf("sendKexInit: %v", err) + } + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + sent++ + + packet = make([]byte, 5) + packet[0] = msgRequestSuccess + for i := 0; i < sent; i++ { + msg, err := trS.readPacket() + if err != nil { + t.Fatalf("server closed too soon: %v", err) + } + + if bytes.Compare(msg, packet) != 0 { + t.Errorf("packet %d: got %q want %q", i, msg, packet) + } + } + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, want 2", len(checker.calls)) + } +} + +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &testChecker{} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + for i := 0; i < 5; i++ { + packet := make([]byte, 251) + packet[0] = msgRequestSuccess + if err := trC.writePacket(packet); err != nil { + t.Errorf("writePacket: %v", err) + } + } + + j := 0 + for ; j < 5; j++ { + _, err := trS.readPacket() + if err != nil { + break + } + } + + if j != 5 { + t.Errorf("got %d, want 5 messages", j) + } + + if len(checker.calls) != 2 { + t.Errorf("got %d key changes, wanted 2", len(checker.calls)) + } +} + +type syncChecker struct { + called chan int +} + +func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + t.called <- 1 + return nil +} + +func TestHandshakeAutoRekeyRead(t *testing.T) { + sync := &syncChecker{make(chan int, 2)} + clientConf := &ClientConfig{ + HostKeyCallback: sync.Check, + } + clientConf.RekeyThreshold = 500 + + trC, trS, err := handshakePair(clientConf, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet := make([]byte, 501) + packet[0] = msgRequestSuccess + if err := trS.writePacket(packet); err != nil { + t.Fatalf("writePacket: %v", err) + } + // While we read out the packet, a key change will be + // initiated. + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + <-sync.called +} + +// errorKeyingTransport generates errors after a given number of +// read/write operations. +type errorKeyingTransport struct { + packetConn + readLeft, writeLeft int +} + +func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { + return nil +} +func (n *errorKeyingTransport) getSessionID() []byte { + return nil +} + +func (n *errorKeyingTransport) writePacket(packet []byte) error { + if n.writeLeft == 0 { + n.Close() + return errors.New("barf") + } + + n.writeLeft-- + return n.packetConn.writePacket(packet) +} + +func (n *errorKeyingTransport) readPacket() ([]byte, error) { + if n.readLeft == 0 { + n.Close() + return nil, errors.New("barf") + } + + n.readLeft-- + return n.packetConn.readPacket() +} + +func TestHandshakeErrorHandlingRead(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1) + } +} + +func TestHandshakeErrorHandlingWrite(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i) + } +} + +// testHandshakeErrorHandlingN runs handshakes, injecting errors. If +// handshakeTransport deadlocks, the go runtime will detect it and +// panic. +func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { + msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) + + a, b := memPipe() + defer a.Close() + defer b.Close() + + key := testSigners["ecdsa"] + serverConf := Config{RekeyThreshold: minRekeyThreshold} + serverConf.SetDefaults() + serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) + serverConn.hostKeys = []Signer{key} + go serverConn.readLoop() + + clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} + clientConf.SetDefaults() + clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) + clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} + go clientConn.readLoop() + + var wg sync.WaitGroup + wg.Add(4) + + for _, hs := range []packetConn{serverConn, clientConn} { + go func(c packetConn) { + for { + err := c.writePacket(msg) + if err != nil { + break + } + } + wg.Done() + }(hs) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + } + wg.Done() + }(hs) + } + + wg.Wait() +} + +func TestDisconnect(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + errMsg := &disconnectMsg{ + Reason: 42, + Message: "such is life", + } + trC.writePacket(Marshal(errMsg)) + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + + packet, err := trS.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + if packet[0] != msgRequestSuccess { + t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 2 succeeded") + } else if !reflect.DeepEqual(err, errMsg) { + t.Errorf("got error %#v, want %#v", err, errMsg) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 3 succeeded") + } +} diff --git a/vendor/golang.org/x/crypto/ssh/kex.go b/vendor/golang.org/x/crypto/ssh/kex.go new file mode 100644 index 0000000..c87fbeb --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/kex.go @@ -0,0 +1,540 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/subtle" + "errors" + "io" + "math/big" + + "golang.org/x/crypto/curve25519" +) + +const ( + kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + kexAlgoECDH256 = "ecdh-sha2-nistp256" + kexAlgoECDH384 = "ecdh-sha2-nistp384" + kexAlgoECDH521 = "ecdh-sha2-nistp521" + kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" +) + +// kexResult captures the outcome of a key exchange. +type kexResult struct { + // Session hash. See also RFC 4253, section 8. + H []byte + + // Shared secret. See also RFC 4253, section 8. + K []byte + + // Host key as hashed into H. + HostKey []byte + + // Signature of H. + Signature []byte + + // A cryptographic hash function that matches the security + // level of the key exchange algorithm. It is used for + // calculating H, and for deriving keys from H and K. + Hash crypto.Hash + + // The session ID, which is the first H computed. This is used + // to derive key material inside the transport. + SessionID []byte +} + +// handshakeMagics contains data that is always included in the +// session hash. +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +func (m *handshakeMagics) write(w io.Writer) { + writeString(w, m.clientVersion) + writeString(w, m.serverVersion) + writeString(w, m.clientKexInit) + writeString(w, m.serverKexInit) +} + +// kexAlgorithm abstracts different key exchange algorithms. +type kexAlgorithm interface { + // Server runs server-side key agreement, signing the result + // with a hostkey. + Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) + + // Client runs the client-side key agreement. Caller is + // responsible for verifying the host key signature. + Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p, pMinus1 *big.Int +} + +func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil +} + +func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + hashFunc := crypto.SHA1 + + var x *big.Int + for { + var err error + if x, err = rand.Int(randSource, group.pMinus1); err != nil { + return nil, err + } + if x.Sign() > 0 { + break + } + } + + X := new(big.Int).Exp(group.g, x, group.p) + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHReply kexDHReplyMsg + if err = Unmarshal(packet, &kexDHReply); err != nil { + return nil, err + } + + kInt, err := group.diffieHellman(kexDHReply.Y, x) + if err != nil { + return nil, err + } + + h := hashFunc.New() + magics.write(h) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHReply.HostKey, + Signature: kexDHReply.Signature, + Hash: crypto.SHA1, + }, nil +} + +func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + hashFunc := crypto.SHA1 + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = Unmarshal(packet, &kexDHInit); err != nil { + return + } + + var y *big.Int + for { + if y, err = rand.Int(randSource, group.pMinus1); err != nil { + return + } + if y.Sign() > 0 { + break + } + } + + Y := new(big.Int).Exp(group.g, y, group.p) + kInt, err := group.diffieHellman(kexDHInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H) + if err != nil { + return nil, err + } + + kexDHReply := kexDHReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHReply) + + err = c.writePacket(packet) + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA1, + }, nil +} + +// ecdh performs Elliptic Curve Diffie-Hellman key exchange as +// described in RFC 5656, section 4. +type ecdh struct { + curve elliptic.Curve +} + +func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + kexInit := kexECDHInitMsg{ + ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), + } + + serialized := Marshal(&kexInit) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) + if err != nil { + return nil, err + } + + // generate shared secret + secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kexInit.ClientPubKey) + writeString(h, reply.EphemeralPubKey) + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: ecHash(kex.curve), + }, nil +} + +// unmarshalECKey parses and checks an EC key. +func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { + x, y = elliptic.Unmarshal(curve, pubkey) + if x == nil { + return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") + } + if !validateECPublicKey(curve, x, y) { + return nil, nil, errors.New("ssh: public key not on curve") + } + return x, y, nil +} + +// validateECPublicKey checks that the point is a valid public key for +// the given curve. See [SEC1], 3.2.2 +func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { + if x.Sign() == 0 && y.Sign() == 0 { + return false + } + + if x.Cmp(curve.Params().P) >= 0 { + return false + } + + if y.Cmp(curve.Params().P) >= 0 { + return false + } + + if !curve.IsOnCurve(x, y) { + return false + } + + // We don't check if N * PubKey == 0, since + // + // - the NIST curves have cofactor = 1, so this is implicit. + // (We don't foresee an implementation that supports non NIST + // curves) + // + // - for ephemeral keys, we don't need to worry about small + // subgroup attacks. + return true +} + +func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexECDHInit kexECDHInitMsg + if err = Unmarshal(packet, &kexECDHInit); err != nil { + return nil, err + } + + clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) + if err != nil { + return nil, err + } + + // We could cache this key across multiple users/multiple + // connection attempts, but the benefit is small. OpenSSH + // generates a new key for each incoming connection. + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) + + // generate shared secret + secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexECDHInit.ClientPubKey) + writeString(h, serializedEphKey) + + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, rand, H) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: serializedEphKey, + HostKey: hostKeyBytes, + Signature: sig, + } + + serialized := Marshal(&reply) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + return &kexResult{ + H: H, + K: K, + HostKey: reply.HostKey, + Signature: sig, + Hash: ecHash(kex.curve), + }, nil +} + +var kexAlgoMap = map[string]kexAlgorithm{} + +func init() { + // This is the group called diffie-hellman-group1-sha1 in RFC + // 4253 and Oakley Group 2 in RFC 2409. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + + kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} + kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} + kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} + kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} +} + +// curve25519sha256 implements the curve25519-sha256@libssh.org key +// agreement protocol, as described in +// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt +type curve25519sha256 struct{} + +type curve25519KeyPair struct { + priv [32]byte + pub [32]byte +} + +func (kp *curve25519KeyPair) generate(rand io.Reader) error { + if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { + return err + } + curve25519.ScalarBaseMult(&kp.pub, &kp.priv) + return nil +} + +// curve25519Zeros is just an array of 32 zero bytes so that we have something +// convenient to compare against in order to reject curve25519 points with the +// wrong order. +var curve25519Zeros [32]byte + +func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + if len(reply.EphemeralPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var servPub, secret [32]byte + copy(servPub[:], reply.EphemeralPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &servPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kp.pub[:]) + writeString(h, reply.EphemeralPubKey) + + kInt := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: crypto.SHA256, + }, nil +} + +func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexInit kexECDHInitMsg + if err = Unmarshal(packet, &kexInit); err != nil { + return + } + + if len(kexInit.ClientPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + + var clientPub, secret [32]byte + copy(clientPub[:], kexInit.ClientPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &clientPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexInit.ClientPubKey) + writeString(h, kp.pub[:]) + + kInt := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + sig, err := signAndMarshal(priv, rand, H) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: kp.pub[:], + HostKey: hostKeyBytes, + Signature: sig, + } + if err := c.writePacket(Marshal(&reply)); err != nil { + return nil, err + } + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA256, + }, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/kex_test.go b/vendor/golang.org/x/crypto/ssh/kex_test.go new file mode 100644 index 0000000..12ca0ac --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/kex_test.go @@ -0,0 +1,50 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Key exchange tests. + +import ( + "crypto/rand" + "reflect" + "testing" +) + +func TestKexes(t *testing.T) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + a, b := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + go func() { + r, e := kex.Client(a, rand.Reader, &magics) + a.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) + b.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + if clientRes.err != nil { + t.Errorf("client: %v", clientRes.err) + } + if serverRes.err != nil { + t.Errorf("server: %v", serverRes.err) + } + if !reflect.DeepEqual(clientRes.result, serverRes.result) { + t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) + } + } +} diff --git a/vendor/golang.org/x/crypto/ssh/keys.go b/vendor/golang.org/x/crypto/ssh/keys.go new file mode 100644 index 0000000..21f7d0d --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/keys.go @@ -0,0 +1,905 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/md5" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "strings" + + "golang.org/x/crypto/ed25519" +) + +// These constants represent the algorithm names for key types supported by this +// package. +const ( + KeyAlgoRSA = "ssh-rsa" + KeyAlgoDSA = "ssh-dss" + KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" + KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" + KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" + KeyAlgoED25519 = "ssh-ed25519" +) + +// parsePubKey parses a public key of the given algorithm. +// Use ParsePublicKey for keys with prepended algorithm. +func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { + switch algo { + case KeyAlgoRSA: + return parseRSA(in) + case KeyAlgoDSA: + return parseDSA(in) + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + return parseECDSA(in) + case KeyAlgoED25519: + return parseED25519(in) + case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01: + cert, err := parseCert(in, certToPrivAlgo(algo)) + if err != nil { + return nil, nil, err + } + return cert, nil, nil + } + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo) +} + +// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format +// (see sshd(8) manual page) once the options and key type fields have been +// removed. +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { + in = bytes.TrimSpace(in) + + i := bytes.IndexAny(in, " \t") + if i == -1 { + i = len(in) + } + base64Key := in[:i] + + key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) + n, err := base64.StdEncoding.Decode(key, base64Key) + if err != nil { + return nil, "", err + } + key = key[:n] + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err + } + comment = string(bytes.TrimSpace(in[i:])) + return out, comment, nil +} + +// ParseKnownHosts parses an entry in the format of the known_hosts file. +// +// The known_hosts format is documented in the sshd(8) manual page. This +// function will parse a single entry from in. On successful return, marker +// will contain the optional marker value (i.e. "cert-authority" or "revoked") +// or else be empty, hosts will contain the hosts that this entry matches, +// pubKey will contain the public key and comment will contain any trailing +// comment at the end of the line. See the sshd(8) manual page for the various +// forms that a host string can take. +// +// The unparsed remainder of the input will be returned in rest. This function +// can be called repeatedly to parse multiple entries. +// +// If no entries were found in the input then err will be io.EOF. Otherwise a +// non-nil err value indicates a parse error. +func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + // Strip out the beginning of the known_host key. + // This is either an optional marker or a (set of) hostname(s). + keyFields := bytes.Fields(in) + if len(keyFields) < 3 || len(keyFields) > 5 { + return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data") + } + + // keyFields[0] is either "@cert-authority", "@revoked" or a comma separated + // list of hosts + marker := "" + if keyFields[0][0] == '@' { + marker = string(keyFields[0][1:]) + keyFields = keyFields[1:] + } + + hosts := string(keyFields[0]) + // keyFields[1] contains the key type (e.g. “ssh-rsa”). + // However, that information is duplicated inside the + // base64-encoded key and so is ignored here. + + key := bytes.Join(keyFields[2:], []byte(" ")) + if pubKey, comment, err = parseAuthorizedKey(key); err != nil { + return "", nil, nil, "", nil, err + } + + return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil + } + + return "", nil, nil, "", nil, io.EOF +} + +// ParseAuthorizedKeys parses a public key from an authorized_keys +// file used in OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil + } + + // No key type recognised. Maybe there's an options field at + // the beginning. + var b byte + inQuote := false + var candidateOptions []string + optionStart := 0 + for i, b = range in { + isEnd := !inQuote && (b == ' ' || b == '\t') + if (b == ',' && !inQuote) || isEnd { + if i-optionStart > 0 { + candidateOptions = append(candidateOptions, string(in[optionStart:i])) + } + optionStart = i + 1 + } + if isEnd { + break + } + if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { + inQuote = !inQuote + } + } + for i < len(in) && (in[i] == ' ' || in[i] == '\t') { + i++ + } + if i == len(in) { + // Invalid line: unmatched quote + in = rest + continue + } + + in = in[i:] + i = bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + options = candidateOptions + return out, comment, options, rest, nil + } + + in = rest + continue + } + + return nil, "", nil, nil, errors.New("ssh: no key found") +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + algo, in, ok := parseString(in) + if !ok { + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") + } + + return out, err +} + +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. +func MarshalAuthorizedKey(key PublicKey) []byte { + b := &bytes.Buffer{} + b.WriteString(key.Type()) + b.WriteByte(' ') + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(key.Marshal()) + e.Close() + b.WriteByte('\n') + return b.Bytes() +} + +// PublicKey is an abstraction of different types of public keys. +type PublicKey interface { + // Type returns the key's type, e.g. "ssh-rsa". + Type() string + + // Marshal returns the serialized key data in SSH wire format, + // with the name prefix. + Marshal() []byte + + // Verify that sig is a signature on the given data using this + // key. This function will hash the data appropriately first. + Verify(data []byte, sig *Signature) error +} + +// CryptoPublicKey, if implemented by a PublicKey, +// returns the underlying crypto.PublicKey form of the key. +type CryptoPublicKey interface { + CryptoPublicKey() crypto.PublicKey +} + +// A Signer can create signatures that verify against a public key. +type Signer interface { + // PublicKey returns an associated PublicKey instance. + PublicKey() PublicKey + + // Sign returns raw signature for the given data. This method + // will apply the hash specified for the keytype to the data. + Sign(rand io.Reader, data []byte) (*Signature, error) +} + +type rsaPublicKey rsa.PublicKey + +func (r *rsaPublicKey) Type() string { + return "ssh-rsa" +} + +// parseRSA parses an RSA key according to RFC 4253, section 6.6. +func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") + } + + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil +} + +func (r *rsaPublicKey) Marshal() []byte { + e := new(big.Int).SetInt64(int64(r.E)) + // RSA publickey struct layout should match the struct used by + // parseRSACert in the x/crypto/ssh/agent package. + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) +} + +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != r.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) +} + +func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*rsa.PublicKey)(r) +} + +type dsaPublicKey dsa.PublicKey + +func (r *dsaPublicKey) Type() string { + return "ssh-dss" +} + +// parseDSA parses an DSA key according to RFC 4253, section 6.6. +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := &dsaPublicKey{ + Parameters: dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + }, + Y: w.Y, + } + return key, w.Rest, nil +} + +func (k *dsaPublicKey) Marshal() []byte { + // DSA publickey struct layout should match the struct used by + // parseDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } + + return Marshal(&w) +} + +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 4253, section 6.6, + // The value for 'dss_signature_blob' is encoded as a string containing + // r, followed by s (which are 160-bit integers, without lengths or + // padding, unsigned, and in network byte order). + // For DSS purposes, sig.Blob should be exactly 40 bytes in length. + if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") + } + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*dsa.PublicKey)(k) +} + +type dsaPrivateKey struct { + *dsa.PrivateKey +} + +func (k *dsaPrivateKey) PublicKey() PublicKey { + return (*dsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := dsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, 40) + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +type ecdsaPublicKey ecdsa.PublicKey + +func (key *ecdsaPublicKey) Type() string { + return "ecdsa-sha2-" + key.nistID() +} + +func (key *ecdsaPublicKey) nistID() string { + switch key.Params().BitSize { + case 256: + return "nistp256" + case 384: + return "nistp384" + case 521: + return "nistp521" + } + panic("ssh: unsupported ecdsa key size") +} + +type ed25519PublicKey ed25519.PublicKey + +func (key ed25519PublicKey) Type() string { + return KeyAlgoED25519 +} + +func parseED25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := ed25519.PublicKey(w.KeyBytes) + + return (ed25519PublicKey)(key), w.Rest, nil +} + +func (key ed25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + }{ + KeyAlgoED25519, + []byte(key), + } + return Marshal(&w) +} + +func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error { + if sig.Format != key.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) + } + + edKey := (ed25519.PublicKey)(key) + if ok := ed25519.Verify(edKey, b, sig.Blob); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return ed25519.PublicKey(k) +} + +func supportedEllipticCurve(curve elliptic.Curve) bool { + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() +} + +// ecHash returns the hash to match the given elliptic curve, see RFC +// 5656, section 6.2.1 +func ecHash(curve elliptic.Curve) crypto.Hash { + bitSize := curve.Params().BitSize + switch { + case bitSize <= 256: + return crypto.SHA256 + case bitSize <= 384: + return crypto.SHA384 + } + return crypto.SHA512 +} + +// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. +func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(ecdsa.PublicKey) + + switch w.Curve { + case "nistp256": + key.Curve = elliptic.P256() + case "nistp384": + key.Curve = elliptic.P384() + case "nistp521": + key.Curve = elliptic.P521() + default: + return nil, nil, errors.New("ssh: unsupported curve") + } + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + return (*ecdsaPublicKey)(key), w.Rest, nil +} + +func (key *ecdsaPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) + // ECDSA publickey struct layout should match the struct used by + // parseECDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + ID string + Key []byte + }{ + key.Type(), + key.nistID(), + keyBytes, + } + + return Marshal(&w) +} + +func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != key.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) + } + + h := ecHash(key.Curve).New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 5656, section 3.1.2, + // The ecdsa_signature_blob value has the following specific encoding: + // mpint r + // mpint s + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*ecdsa.PublicKey)(k) +} + +// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey, +// *ecdsa.PrivateKey or any other crypto.Signer and returns a corresponding +// Signer instance. ECDSA keys must use P-256, P-384 or P-521. +func NewSignerFromKey(key interface{}) (Signer, error) { + switch key := key.(type) { + case crypto.Signer: + return NewSignerFromSigner(key) + case *dsa.PrivateKey: + return &dsaPrivateKey{key}, nil + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +type wrappedSigner struct { + signer crypto.Signer + pubKey PublicKey +} + +// NewSignerFromSigner takes any crypto.Signer implementation and +// returns a corresponding Signer interface. This can be used, for +// example, with keys kept in hardware modules. +func NewSignerFromSigner(signer crypto.Signer) (Signer, error) { + pubKey, err := NewPublicKey(signer.Public()) + if err != nil { + return nil, err + } + + return &wrappedSigner{signer, pubKey}, nil +} + +func (s *wrappedSigner) PublicKey() PublicKey { + return s.pubKey +} + +func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + var hashFunc crypto.Hash + + switch key := s.pubKey.(type) { + case *rsaPublicKey, *dsaPublicKey: + hashFunc = crypto.SHA1 + case *ecdsaPublicKey: + hashFunc = ecHash(key.Curve) + case ed25519PublicKey: + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } + + var digest []byte + if hashFunc != 0 { + h := hashFunc.New() + h.Write(data) + digest = h.Sum(nil) + } else { + digest = data + } + + signature, err := s.signer.Sign(rand, digest, hashFunc) + if err != nil { + return nil, err + } + + // crypto.Signer.Sign is expected to return an ASN.1-encoded signature + // for ECDSA and DSA, but that's not the encoding expected by SSH, so + // re-encode. + switch s.pubKey.(type) { + case *ecdsaPublicKey, *dsaPublicKey: + type asn1Signature struct { + R, S *big.Int + } + asn1Sig := new(asn1Signature) + _, err := asn1.Unmarshal(signature, asn1Sig) + if err != nil { + return nil, err + } + + switch s.pubKey.(type) { + case *ecdsaPublicKey: + signature = Marshal(asn1Sig) + + case *dsaPublicKey: + signature = make([]byte, 40) + r := asn1Sig.R.Bytes() + s := asn1Sig.S.Bytes() + copy(signature[20-len(r):20], r) + copy(signature[40-len(s):40], s) + } + } + + return &Signature{ + Format: s.pubKey.Type(), + Blob: signature, + }, nil +} + +// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, +// or ed25519.PublicKey returns a corresponding PublicKey instance. +// ECDSA keys must use P-256, P-384 or P-521. +func NewPublicKey(key interface{}) (PublicKey, error) { + switch key := key.(type) { + case *rsa.PublicKey: + return (*rsaPublicKey)(key), nil + case *ecdsa.PublicKey: + if !supportedEllipticCurve(key.Curve) { + return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.") + } + return (*ecdsaPublicKey)(key), nil + case *dsa.PublicKey: + return (*dsaPublicKey)(key), nil + case ed25519.PublicKey: + return (ed25519PublicKey)(key), nil + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports +// the same keys as ParseRawPrivateKey. +func ParsePrivateKey(pemBytes []byte) (Signer, error) { + key, err := ParseRawPrivateKey(pemBytes) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// encryptedBlock tells whether a private key is +// encrypted by examining its Proc-Type header +// for a mention of ENCRYPTED +// according to RFC 1421 Section 4.6.1.1. +func encryptedBlock(block *pem.Block) bool { + return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED") +} + +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It +// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys. +func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("ssh: no key found") + } + + if encryptedBlock(block) { + return nil, errors.New("ssh: cannot decode encrypted private keys") + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + case "DSA PRIVATE KEY": + return ParseDSAPrivateKey(block.Bytes) + case "OPENSSH PRIVATE KEY": + return parseOpenSSHPrivateKey(block.Bytes) + default: + return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) + } +} + +// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as +// specified by the OpenSSL DSA man page. +func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { + Version int + P *big.Int + Q *big.Int + G *big.Int + Priv *big.Int + Pub *big.Int + } + rest, err := asn1.Unmarshal(der, &k) + if err != nil { + return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) + } + if len(rest) > 0 { + return nil, errors.New("ssh: garbage after DSA key") + } + + return &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Priv, + }, + X: k.Pub, + }, nil +} + +// Implemented based on the documentation at +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key +func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { + magic := append([]byte("openssh-key-v1"), 0) + if !bytes.Equal(magic, key[0:len(magic)]) { + return nil, errors.New("ssh: invalid openssh private key format") + } + remaining := key[len(magic):] + + var w struct { + CipherName string + KdfName string + KdfOpts string + NumKeys uint32 + PubKey []byte + PrivKeyBlock []byte + } + + if err := Unmarshal(remaining, &w); err != nil { + return nil, err + } + + pk1 := struct { + Check1 uint32 + Check2 uint32 + Keytype string + Pub []byte + Priv []byte + Comment string + Pad []byte `ssh:"rest"` + }{} + + if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil { + return nil, err + } + + if pk1.Check1 != pk1.Check2 { + return nil, errors.New("ssh: checkint mismatch") + } + + // we only handle ed25519 keys currently + if pk1.Keytype != KeyAlgoED25519 { + return nil, errors.New("ssh: unhandled key type") + } + + for i, b := range pk1.Pad { + if int(b) != i+1 { + return nil, errors.New("ssh: padding not as expected") + } + } + + if len(pk1.Priv) != ed25519.PrivateKeySize { + return nil, errors.New("ssh: private key unexpected length") + } + + pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) + copy(pk, pk1.Priv) + return &pk, nil +} + +// FingerprintLegacyMD5 returns the user presentation of the key's +// fingerprint as described by RFC 4716 section 4. +func FingerprintLegacyMD5(pubKey PublicKey) string { + md5sum := md5.Sum(pubKey.Marshal()) + hexarray := make([]string, len(md5sum)) + for i, c := range md5sum { + hexarray[i] = hex.EncodeToString([]byte{c}) + } + return strings.Join(hexarray, ":") +} + +// FingerprintSHA256 returns the user presentation of the key's +// fingerprint as unpadded base64 encoded sha256 hash. +// This format was introduced from OpenSSH 6.8. +// https://www.openssh.com/txt/release-6.8 +// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding) +func FingerprintSHA256(pubKey PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} diff --git a/vendor/golang.org/x/crypto/ssh/keys_test.go b/vendor/golang.org/x/crypto/ssh/keys_test.go new file mode 100644 index 0000000..a65e87e --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/keys_test.go @@ -0,0 +1,474 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "reflect" + "strings" + "testing" + + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh/testdata" +) + +func rawKey(pub PublicKey) interface{} { + switch k := pub.(type) { + case *rsaPublicKey: + return (*rsa.PublicKey)(k) + case *dsaPublicKey: + return (*dsa.PublicKey)(k) + case *ecdsaPublicKey: + return (*ecdsa.PublicKey)(k) + case ed25519PublicKey: + return (ed25519.PublicKey)(k) + case *Certificate: + return k + } + panic("unknown key type") +} + +func TestKeyMarshalParse(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + roundtrip, err := ParsePublicKey(pub.Marshal()) + if err != nil { + t.Errorf("ParsePublicKey(%T): %v", pub, err) + } + + k1 := rawKey(pub) + k2 := rawKey(roundtrip) + + if !reflect.DeepEqual(k1, k2) { + t.Errorf("got %#v in roundtrip, want %#v", k2, k1) + } + } +} + +func TestUnsupportedCurves(t *testing.T) { + raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPrivateKey should not succeed with P-224, got: %v", err) + } + + if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err) + } +} + +func TestNewPublicKey(t *testing.T) { + for _, k := range testSigners { + raw := rawKey(k.PublicKey()) + // Skip certificates, as NewPublicKey does not support them. + if _, ok := raw.(*Certificate); ok { + continue + } + pub, err := NewPublicKey(raw) + if err != nil { + t.Errorf("NewPublicKey(%#v): %v", raw, err) + } + if !reflect.DeepEqual(k.PublicKey(), pub) { + t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) + } + } +} + +func TestKeySignVerify(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + + data := []byte("sign me") + sig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } +} + +func TestParseRSAPrivateKey(t *testing.T) { + key := testPrivateKeys["rsa"] + + rsa, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *rsa.PrivateKey", rsa) + } + + if err := rsa.Validate(); err != nil { + t.Errorf("Validate: %v", err) + } +} + +func TestParseECPrivateKey(t *testing.T) { + key := testPrivateKeys["ecdsa"] + + ecKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) + } + + if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { + t.Fatalf("public key does not validate.") + } +} + +// See Issue https://github.com/golang/go/issues/6650. +func TestParseEncryptedPrivateKeysFails(t *testing.T) { + const wantSubstring = "encrypted" + for i, tt := range testdata.PEMEncryptedKeys { + _, err := ParsePrivateKey(tt.PEMBytes) + if err == nil { + t.Errorf("#%d key %s: ParsePrivateKey successfully parsed, expected an error", i, tt.Name) + continue + } + + if !strings.Contains(err.Error(), wantSubstring) { + t.Errorf("#%d key %s: got error %q, want substring %q", i, tt.Name, err, wantSubstring) + } + } +} + +func TestParseDSA(t *testing.T) { + // We actually exercise the ParsePrivateKey codepath here, as opposed to + // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go + // uses. + s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) + if err != nil { + t.Fatalf("ParsePrivateKey returned error: %s", err) + } + + data := []byte("sign me") + sig, err := s.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("dsa.Sign: %v", err) + } + + if err := s.PublicKey().Verify(data, sig); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +// Tests for authorized_keys parsing. + +// getTestKey returns a public key, and its base64 encoding. +func getTestKey() (PublicKey, string) { + k := testPublicKeys["rsa"] + + b := &bytes.Buffer{} + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(k.Marshal()) + e.Close() + + return k, b.String() +} + +func TestMarshalParsePublicKey(t *testing.T) { + pub, pubSerialized := getTestKey() + line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) + + authKeys := MarshalAuthorizedKey(pub) + actualFields := strings.Fields(string(authKeys)) + if len(actualFields) == 0 { + t.Fatalf("failed authKeys: %v", authKeys) + } + + // drop the comment + expectedFields := strings.Fields(line)[0:2] + + if !reflect.DeepEqual(actualFields, expectedFields) { + t.Errorf("got %v, expected %v", actualFields, expectedFields) + } + + actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) + if err != nil { + t.Fatalf("cannot parse %v: %v", line, err) + } + if !reflect.DeepEqual(actPub, pub) { + t.Errorf("got %v, expected %v", actPub, pub) + } +} + +type authResult struct { + pubKey PublicKey + options []string + comments string + rest string + ok bool +} + +func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { + rest := authKeys + var values []authResult + for len(rest) > 0 { + var r authResult + var err error + r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) + r.ok = (err == nil) + t.Log(err) + r.rest = string(rest) + values = append(values, r) + } + + if !reflect.DeepEqual(values, expected) { + t.Errorf("got %#v, expected %#v", values, expected) + } +} + +func TestAuthorizedKeyBasic(t *testing.T) { + pub, pubSerialized := getTestKey() + line := "ssh-rsa " + pubSerialized + " user@host" + testAuthorizedKeys(t, []byte(line), + []authResult{ + {pub, nil, "user@host", "", true}, + }) +} + +func TestAuth(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithOptions := []string{ + `# comments to ignore before any keys...`, + ``, + `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, + `# comments to ignore, along with a blank line`, + ``, + `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, + ``, + `# more comments, plus a invalid entry`, + `ssh-rsa data-that-will-not-parse user@host3`, + } + for _, eol := range []string{"\n", "\r\n"} { + authOptions := strings.Join(authWithOptions, eol) + rest2 := strings.Join(authWithOptions[3:], eol) + rest3 := strings.Join(authWithOptions[6:], eol) + testAuthorizedKeys(t, []byte(authOptions), []authResult{ + {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, + {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, + {nil, nil, "", "", false}, + }) + } +} + +func TestAuthWithQuotedSpaceInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedCommaInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ + {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedQuoteInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) + authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ + {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) + + testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ + {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, + }) +} + +func TestAuthWithInvalidSpace(t *testing.T) { + _, pubSerialized := getTestKey() + authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +#more to follow but still no valid keys`) + testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ + {nil, nil, "", "", false}, + }) +} + +func TestAuthWithMissingQuote(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) + + testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ + {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, + }) +} + +func TestInvalidEntry(t *testing.T) { + authInvalid := []byte(`ssh-rsa`) + _, _, _, _, err := ParseAuthorizedKey(authInvalid) + if err == nil { + t.Errorf("got valid entry for %q", authInvalid) + } +} + +var knownHostsParseTests = []struct { + input string + err string + + marker string + comment string + hosts []string + rest string +}{ + { + "", + "EOF", + + "", "", nil, "", + }, + { + "# Just a comment", + "EOF", + + "", "", nil, "", + }, + { + " \t ", + "EOF", + + "", "", nil, "", + }, + { + "localhost ssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line", + "", + + "", "comment comment", []string{"localhost"}, "next line", + }, + { + "localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}", + "", + + "marker", "", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd", + "short read", + + "", "", nil, "", + }, +} + +func TestKnownHostsParsing(t *testing.T) { + rsaPub, rsaPubSerialized := getTestKey() + + for i, test := range knownHostsParseTests { + var expectedKey PublicKey + const rsaKeyToken = "{RSAPUB}" + + input := test.input + if strings.Contains(input, rsaKeyToken) { + expectedKey = rsaPub + input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1) + } + + marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input)) + if err != nil { + if len(test.err) == 0 { + t.Errorf("#%d: unexpectedly failed with %q", i, err) + } else if !strings.Contains(err.Error(), test.err) { + t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err) + } + continue + } else if len(test.err) != 0 { + t.Errorf("#%d: succeeded but expected error including %q", i, test.err) + continue + } + + if !reflect.DeepEqual(expectedKey, pubKey) { + t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey) + } + + if marker != test.marker { + t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker) + } + + if comment != test.comment { + t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment) + } + + if !reflect.DeepEqual(test.hosts, hosts) { + t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts) + } + + if rest := string(rest); rest != test.rest { + t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest) + } + } +} + +func TestFingerprintLegacyMD5(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintLegacyMD5(pub) + want := "fb:61:6d:1a:e3:f0:95:45:3c:a0:79:be:4a:93:63:66" // ssh-keygen -lf -E md5 rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestFingerprintSHA256(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintSHA256(pub) + want := "SHA256:Anr3LjZK8YVpjrxu79myrW9Hrb/wpcMNpVvTq/RcBm8" // ssh-keygen -lf rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/mac.go b/vendor/golang.org/x/crypto/ssh/mac.go new file mode 100644 index 0000000..07744ad --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/mac.go @@ -0,0 +1,57 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Message authentication support + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "hash" +) + +type macMode struct { + keySize int + new func(key []byte) hash.Hash +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum(in []byte) []byte { + out := t.hmac.Sum(in) + return out[:len(in)+t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } + +var macModes = map[string]*macMode{ + "hmac-sha2-256": {32, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha1": {20, func(key []byte) hash.Hash { + return hmac.New(sha1.New, key) + }}, + "hmac-sha1-96": {20, func(key []byte) hash.Hash { + return truncatingMAC{12, hmac.New(sha1.New, key)} + }}, +} diff --git a/vendor/golang.org/x/crypto/ssh/mempipe_test.go b/vendor/golang.org/x/crypto/ssh/mempipe_test.go new file mode 100644 index 0000000..8697cd6 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/mempipe_test.go @@ -0,0 +1,110 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" + "testing" +) + +// An in-memory packetConn. It is safe to call Close and writePacket +// from different goroutines. +type memTransport struct { + eof bool + pending [][]byte + write *memTransport + sync.Mutex + *sync.Cond +} + +func (t *memTransport) readPacket() ([]byte, error) { + t.Lock() + defer t.Unlock() + for { + if len(t.pending) > 0 { + r := t.pending[0] + t.pending = t.pending[1:] + return r, nil + } + if t.eof { + return nil, io.EOF + } + t.Cond.Wait() + } +} + +func (t *memTransport) closeSelf() error { + t.Lock() + defer t.Unlock() + if t.eof { + return io.EOF + } + t.eof = true + t.Cond.Broadcast() + return nil +} + +func (t *memTransport) Close() error { + err := t.write.closeSelf() + t.closeSelf() + return err +} + +func (t *memTransport) writePacket(p []byte) error { + t.write.Lock() + defer t.write.Unlock() + if t.write.eof { + return io.EOF + } + c := make([]byte, len(p)) + copy(c, p) + t.write.pending = append(t.write.pending, c) + t.write.Cond.Signal() + return nil +} + +func memPipe() (a, b packetConn) { + t1 := memTransport{} + t2 := memTransport{} + t1.write = &t2 + t2.write = &t1 + t1.Cond = sync.NewCond(&t1.Mutex) + t2.Cond = sync.NewCond(&t2.Mutex) + return &t1, &t2 +} + +func TestMemPipe(t *testing.T) { + a, b := memPipe() + if err := a.writePacket([]byte{42}); err != nil { + t.Fatalf("writePacket: %v", err) + } + if err := a.Close(); err != nil { + t.Fatal("Close: ", err) + } + p, err := b.readPacket() + if err != nil { + t.Fatal("readPacket: ", err) + } + if len(p) != 1 || p[0] != 42 { + t.Fatalf("got %v, want {42}", p) + } + p, err = b.readPacket() + if err != io.EOF { + t.Fatalf("got %v, %v, want EOF", p, err) + } +} + +func TestDoubleClose(t *testing.T) { + a, _ := memPipe() + err := a.Close() + if err != nil { + t.Errorf("Close: %v", err) + } + err = a.Close() + if err != io.EOF { + t.Errorf("expect EOF on double close.") + } +} diff --git a/vendor/golang.org/x/crypto/ssh/messages.go b/vendor/golang.org/x/crypto/ssh/messages.go new file mode 100644 index 0000000..e6ecd3a --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/messages.go @@ -0,0 +1,758 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" + "strings" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from [SSH-PARAMETERS]. +const ( + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 + + // Standard authentication messages + msgUserAuthSuccess = 52 + msgUserAuthBanner = 53 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// ssh tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) +} + +// See RFC 4253, section 7.1. +const msgKexInit = 20 + +type kexInitMsg struct { + Cookie [16]byte `sshtype:"20"` + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +// See RFC 4253, section 8. + +// Diffie-Helman +const msgKexDHInit = 30 + +type kexDHInitMsg struct { + X *big.Int `sshtype:"30"` +} + +const msgKexECDHInit = 30 + +type kexECDHInitMsg struct { + ClientPubKey []byte `sshtype:"30"` +} + +const msgKexECDHReply = 31 + +type kexECDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + EphemeralPubKey []byte + Signature []byte +} + +const msgKexDHReply = 31 + +type kexDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + Y *big.Int + Signature []byte +} + +// See RFC 4253, section 10. +const msgServiceRequest = 5 + +type serviceRequestMsg struct { + Service string `sshtype:"5"` +} + +// See RFC 4253, section 10. +const msgServiceAccept = 6 + +type serviceAcceptMsg struct { + Service string `sshtype:"6"` +} + +// See RFC 4252, section 5. +const msgUserAuthRequest = 50 + +type userAuthRequestMsg struct { + User string `sshtype:"50"` + Service string + Method string + Payload []byte `ssh:"rest"` +} + +// Used for debug printouts of packets. +type userAuthSuccessMsg struct { +} + +// See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + +type userAuthFailureMsg struct { + Methods []string `sshtype:"51"` + PartialSuccess bool +} + +// See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + +type userAuthInfoRequestMsg struct { + User string `sshtype:"60"` + Instruction string + DeprecatedLanguage string + NumPrompts uint32 + Prompts []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersId uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelExtendedData = 95 +const msgChannelData = 94 + +// Used for debug print outs of packets. +type channelDataMsg struct { + PeersId uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersId uint32 `sshtype:"91"` + MyId uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenFailure = 92 + +type channelOpenFailureMsg struct { + PeersId uint32 `sshtype:"92"` + Reason RejectionReason + Message string + Language string +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersId uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersId uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersId uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersId uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersId uint32 `sshtype:"96"` +} + +// See RFC 4254, section 4 +const msgGlobalRequest = 80 + +type globalRequestMsg struct { + Type string `sshtype:"80"` + WantReply bool + Data []byte `ssh:"rest"` +} + +// See RFC 4254, section 4 +const msgRequestSuccess = 81 + +type globalRequestSuccessMsg struct { + Data []byte `ssh:"rest" sshtype:"81"` +} + +// See RFC 4254, section 4 +const msgRequestFailure = 82 + +type globalRequestFailureMsg struct { + Data []byte `ssh:"rest" sshtype:"82"` +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersId uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + +type userAuthPubKeyOkMsg struct { + Algo string `sshtype:"60"` + PubKey []byte +} + +// typeTags returns the possible type bytes for the given reflect.Type, which +// should be a struct. The possible values are separated by a '|' character. +func typeTags(structType reflect.Type) (tags []byte) { + tagStr := structType.Field(0).Tag.Get("sshtype") + + for _, tag := range strings.Split(tagStr, "|") { + i, err := strconv.Atoi(tag) + if err == nil { + tags = append(tags, byte(i)) + } + } + + return tags +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem + } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") + +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a '|'-separated set of numbers +// in decimal, the packet must start with one of those numbers. In +// case of error, Unmarshal returns a ParseError or +// UnexpectedMessageError. +func Unmarshal(data []byte, out interface{}) error { + v := reflect.ValueOf(out).Elem() + structType := v.Type() + expectedTypes := typeTags(structType) + + var expectedType byte + if len(expectedTypes) > 0 { + expectedType = expectedTypes[0] + } + + if len(data) == 0 { + return parseError(expectedType) + } + + if len(expectedTypes) > 0 { + goodType := false + for _, e := range expectedTypes { + if e > 0 && data[0] == e { + goodType = true + break + } + } + if !goodType { + return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) + } + data = data[1:] + } + + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(data) < 1 { + return errShortRead + } + field.SetBool(data[0] != 0) + data = data[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + return fieldError(structType, i, "array of unsupported type") + } + if len(data) < t.Len() { + return errShortRead + } + for j, n := 0, t.Len(); j < n; j++ { + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead + } + field.SetUint(u64) + case reflect.Uint32: + var u32 uint32 + if u32, data, ok = parseUint32(data); !ok { + return errShortRead + } + field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] + case reflect.String: + var s []byte + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag.Get("ssh") == "rest" { + field.Set(reflect.ValueOf(data)) + data = nil + } else { + var s []byte + if s, data, ok = parseString(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(s)) + } + case reflect.String: + var nl []string + if nl, data, ok = parseNameList(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(nl)) + default: + return fieldError(structType, i, "slice of unsupported type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, data, ok = parseInt(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(n)) + } else { + return fieldError(structType, i, "pointer to unsupported type") + } + default: + return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) + } + } + + if len(data) != 0 { + return parseError(expectedType) + } + + return nil +} + +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { + out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgTypes := typeTags(v.Type()) + if len(msgTypes) > 0 { + out = append(out, msgTypes[0]) + } + + for i, n := 0, v.NumField(); i < n; i++ { + field := v.Field(i) + switch t := field.Type(); t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) + } + for j, l := 0, t.Len(); j < l; j++ { + out = append(out, uint8(field.Index(j).Uint())) + } + case reflect.Uint32: + out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) + case reflect.String: + s := field.String() + out = appendInt(out, len(s)) + out = append(out, s...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if v.Type().Field(i).Tag.Get("ssh") != "rest" { + out = appendInt(out, field.Len()) + } + out = append(out, field.Bytes()...) + case reflect.String: + offset := len(out) + out = appendU32(out, 0) + if n := field.Len(); n > 0 { + for j := 0; j < n; j++ { + f := field.Index(j) + if j != 0 { + out = append(out, ',') + } + out = append(out, f.String()...) + } + // overwrite length value + binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) + } + default: + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + in = in[4:] + if uint32(len(in)) < length { + return + } + out = in[:length] + rest = in[length:] + ok = true + return +} + +var ( + comma = []byte{','} + emptyNameList = []string{} +) + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + out = emptyNameList + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} + +func parseUint64(in []byte) (uint64, []byte, bool) { + if len(in) < 8 { + return 0, nil, false + } + return binary.BigEndian.Uint64(in), in[8:], true +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalUint32(to []byte, n uint32) []byte { + binary.BigEndian.PutUint32(to, n) + return to[4:] +} + +func marshalUint64(to []byte, n uint64) []byte { + binary.BigEndian.PutUint64(to, n) + return to[8:] +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(n int) int { + return 4 + n +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) + +// Decode a packet into its corresponding message. +func decode(packet []byte) (interface{}, error) { + var msg interface{} + switch packet[0] { + case msgDisconnect: + msg = new(disconnectMsg) + case msgServiceRequest: + msg = new(serviceRequestMsg) + case msgServiceAccept: + msg = new(serviceAcceptMsg) + case msgKexInit: + msg = new(kexInitMsg) + case msgKexDHInit: + msg = new(kexDHInitMsg) + case msgKexDHReply: + msg = new(kexDHReplyMsg) + case msgUserAuthRequest: + msg = new(userAuthRequestMsg) + case msgUserAuthSuccess: + return new(userAuthSuccessMsg), nil + case msgUserAuthFailure: + msg = new(userAuthFailureMsg) + case msgUserAuthPubKeyOk: + msg = new(userAuthPubKeyOkMsg) + case msgGlobalRequest: + msg = new(globalRequestMsg) + case msgRequestSuccess: + msg = new(globalRequestSuccessMsg) + case msgRequestFailure: + msg = new(globalRequestFailureMsg) + case msgChannelOpen: + msg = new(channelOpenMsg) + case msgChannelData: + msg = new(channelDataMsg) + case msgChannelOpenConfirm: + msg = new(channelOpenConfirmMsg) + case msgChannelOpenFailure: + msg = new(channelOpenFailureMsg) + case msgChannelWindowAdjust: + msg = new(windowAdjustMsg) + case msgChannelEOF: + msg = new(channelEOFMsg) + case msgChannelClose: + msg = new(channelCloseMsg) + case msgChannelRequest: + msg = new(channelRequestMsg) + case msgChannelSuccess: + msg = new(channelRequestSuccessMsg) + case msgChannelFailure: + msg = new(channelRequestFailureMsg) + default: + return nil, unexpectedMessageError(0, packet[0]) + } + if err := Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/messages_test.go b/vendor/golang.org/x/crypto/ssh/messages_test.go new file mode 100644 index 0000000..e790764 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/messages_test.go @@ -0,0 +1,288 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +var intLengthTests = []struct { + val, length int +}{ + {0, 4 + 0}, + {1, 4 + 1}, + {127, 4 + 1}, + {128, 4 + 2}, + {-1, 4 + 1}, +} + +func TestIntLength(t *testing.T) { + for _, test := range intLengthTests { + v := new(big.Int).SetInt64(int64(test.val)) + length := intLength(v) + if length != test.length { + t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) + } + } +} + +type msgAllTypes struct { + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + Bytes []byte + Int *big.Int + Rest []byte `ssh:"rest"` +} + +func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { + m := &msgAllTypes{} + m.Bool = rand.Intn(2) == 1 + randomBytes(m.Array[:], rand) + m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) + m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) + m.Uint8 = uint8(rand.Intn(1 << 8)) + m.String = string(m.Array[:]) + m.Strings = randomNameList(rand) + m.Bytes = m.Array[:] + m.Int = randomInt(rand) + m.Rest = m.Array[:] + return reflect.ValueOf(m) +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + iface := &msgAllTypes{} + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("failed to create value") + break + } + + m1 := v.Elem().Interface() + m2 := iface + + marshaled := Marshal(m1) + if err := Unmarshal(marshaled, m2); err != nil { + t.Errorf("Unmarshal %#v: %s", m1, err) + break + } + + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) + break + } + } +} + +func TestUnmarshalEmptyPacket(t *testing.T) { + var b []byte + var m channelRequestSuccessMsg + if err := Unmarshal(b, &m); err == nil { + t.Fatalf("unmarshal of empty slice succeeded") + } +} + +func TestUnmarshalUnexpectedPacket(t *testing.T) { + type S struct { + I uint32 `sshtype:"43"` + S string + B bool + } + + s := S{11, "hello", true} + packet := Marshal(s) + packet[0] = 42 + roundtrip := S{} + err := Unmarshal(packet, &roundtrip) + if err == nil { + t.Fatal("expected error, not nil") + } +} + +func TestMarshalPtr(t *testing.T) { + s := struct { + S string + }{"hello"} + + m1 := Marshal(s) + m2 := Marshal(&s) + if !bytes.Equal(m1, m2) { + t.Errorf("got %q, want %q for marshaled pointer", m2, m1) + } +} + +func TestBareMarshalUnmarshal(t *testing.T) { + type S struct { + I uint32 + S string + B bool + } + + s := S{42, "hello", true} + packet := Marshal(s) + roundtrip := S{} + Unmarshal(packet, &roundtrip) + + if !reflect.DeepEqual(s, roundtrip) { + t.Errorf("got %#v, want %#v", roundtrip, s) + } +} + +func TestBareMarshal(t *testing.T) { + type S2 struct { + I uint32 + } + s := S2{42} + packet := Marshal(s) + i, rest, ok := parseUint32(packet) + if len(rest) > 0 || !ok { + t.Errorf("parseInt(%q): parse error", packet) + } + if i != s.I { + t.Errorf("got %d, want %d", i, s.I) + } +} + +func TestUnmarshalShortKexInitPacket(t *testing.T) { + // This used to panic. + // Issue 11348 + packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} + kim := &kexInitMsg{} + if err := Unmarshal(packet, kim); err == nil { + t.Error("truncated packet unmarshaled without error") + } +} + +func TestMarshalMultiTag(t *testing.T) { + var res struct { + A uint32 `sshtype:"1|2"` + } + + good1 := struct { + A uint32 `sshtype:"1"` + }{ + 1, + } + good2 := struct { + A uint32 `sshtype:"2"` + }{ + 1, + } + + if e := Unmarshal(Marshal(good1), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + if e := Unmarshal(Marshal(good2), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + bad1 := struct { + A uint32 `sshtype:"3"` + }{ + 1, + } + if e := Unmarshal(Marshal(bad1), &res); e == nil { + t.Errorf("bad struct unmarshaled without error") + } +} + +func randomBytes(out []byte, rand *rand.Rand) { + for i := 0; i < len(out); i++ { + out[i] = byte(rand.Int31()) + } +} + +func randomNameList(rand *rand.Rand) []string { + ret := make([]string, rand.Int31()&15) + for i := range ret { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + ret[i] = string(s) + } + return ret +} + +func randomInt(rand *rand.Rand) *big.Int { + return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) +} + +func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + ki := &kexInitMsg{} + randomBytes(ki.Cookie[:], rand) + ki.KexAlgos = randomNameList(rand) + ki.ServerHostKeyAlgos = randomNameList(rand) + ki.CiphersClientServer = randomNameList(rand) + ki.CiphersServerClient = randomNameList(rand) + ki.MACsClientServer = randomNameList(rand) + ki.MACsServerClient = randomNameList(rand) + ki.CompressionClientServer = randomNameList(rand) + ki.CompressionServerClient = randomNameList(rand) + ki.LanguagesClientServer = randomNameList(rand) + ki.LanguagesServerClient = randomNameList(rand) + if rand.Int31()&1 == 1 { + ki.FirstKexFollows = true + } + return reflect.ValueOf(ki) +} + +func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + dhi := &kexDHInitMsg{} + dhi.X = randomInt(rand) + return reflect.ValueOf(dhi) +} + +var ( + _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + + _kexInit = Marshal(_kexInitMsg) + _kexDHInit = Marshal(_kexDHInitMsg) +) + +func BenchmarkMarshalKexInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexInitMsg) + } +} + +func BenchmarkUnmarshalKexInitMsg(b *testing.B) { + m := new(kexInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexInit, m) + } +} + +func BenchmarkMarshalKexDHInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexDHInitMsg) + } +} + +func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { + m := new(kexDHInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexDHInit, m) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/mux.go b/vendor/golang.org/x/crypto/ssh/mux.go new file mode 100644 index 0000000..f3a3ddd --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/mux.go @@ -0,0 +1,330 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// When debugging, each new chanList instantiation has a different +// offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. +func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, 16), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, 16), + errCond: newCond(), + } + if debugMux { + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + } + + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + if debugMux { + log.Printf("send global(%d): %#v", m.chanList.offset, msg) + } + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. +func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + } + + // assume a channel packet. + if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return fmt.Errorf("ssh: invalid channel %d", id) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersId: msg.PeersId, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersId + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersId: ch.localId, + } + if err := m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/mux_test.go b/vendor/golang.org/x/crypto/ssh/mux_test.go new file mode 100644 index 0000000..591aae8 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/mux_test.go @@ -0,0 +1,502 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "io/ioutil" + "sync" + "testing" +) + +func muxPair() (*mux, *mux) { + a, b := memPipe() + + s := newMux(a) + c := newMux(b) + + return s, c +} + +// Returns both ends of a channel, and the mux for the the 2nd +// channel. +func channelPair(t *testing.T) (*channel, *channel, *mux) { + c, s := muxPair() + + res := make(chan *channel, 1) + go func() { + newCh, ok := <-s.incomingChannels + if !ok { + t.Fatalf("No incoming channel") + } + if newCh.ChannelType() != "chan" { + t.Fatalf("got type %q want chan", newCh.ChannelType()) + } + ch, _, err := newCh.Accept() + if err != nil { + t.Fatalf("Accept %v", err) + } + res <- ch.(*channel) + }() + + ch, err := c.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + + return <-res, ch, c +} + +// Test that stderr and stdout can be addressed from different +// goroutines. This is intended for use with the race detector. +func TestMuxChannelExtendedThreadSafety(t *testing.T) { + writer, reader, mux := channelPair(t) + defer writer.Close() + defer reader.Close() + defer mux.Close() + + var wr, rd sync.WaitGroup + magic := "hello world" + + wr.Add(2) + go func() { + io.WriteString(writer, magic) + wr.Done() + }() + go func() { + io.WriteString(writer.Stderr(), magic) + wr.Done() + }() + + rd.Add(2) + go func() { + c, err := ioutil.ReadAll(reader) + if string(c) != magic { + t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + go func() { + c, err := ioutil.ReadAll(reader.Stderr()) + if string(c) != magic { + t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + + wr.Wait() + writer.CloseWrite() + rd.Wait() +} + +func TestMuxReadWrite(t *testing.T) { + s, c, mux := channelPair(t) + defer s.Close() + defer c.Close() + defer mux.Close() + + magic := "hello world" + magicExt := "hello stderr" + go func() { + _, err := s.Write([]byte(magic)) + if err != nil { + t.Fatalf("Write: %v", err) + } + _, err = s.Extended(1).Write([]byte(magicExt)) + if err != nil { + t.Fatalf("Write: %v", err) + } + err = s.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + }() + + var buf [1024]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + got := string(buf[:n]) + if got != magic { + t.Fatalf("server: got %q want %q", got, magic) + } + + n, err = c.Extended(1).Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + + got = string(buf[:n]) + if got != magicExt { + t.Fatalf("server: got %q want %q", got, magic) + } +} + +func TestMuxChannelOverflow(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + writer.Write(make([]byte, 1)) + wDone <- 1 + }() + writer.remoteWin.waitWriterBlocked() + + // Send 1 byte. + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], writer.remoteId) + marshalUint32(packet[5:], uint32(1)) + packet[9] = 42 + + if err := writer.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + if _, err := reader.SendRequest("hello", true, nil); err == nil { + t.Errorf("SendRequest succeeded.") + } + <-wDone +} + +func TestMuxChannelCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + reader.Close() + <-wDone +} + +func TestMuxConnectionCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + wDone := make(chan int, 1) + go func() { + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + wDone <- 1 + }() + + writer.remoteWin.waitWriterBlocked() + mux.Close() + <-wDone +} + +func TestMuxReject(t *testing.T) { + client, server := muxPair() + defer server.Close() + defer client.Close() + + go func() { + ch, ok := <-server.incomingChannels + if !ok { + t.Fatalf("Accept") + } + if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { + t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + } + ch.Reject(RejectionReason(42), "message") + }() + + ch, err := client.openChannel("ch", []byte("extra")) + if ch != nil { + t.Fatal("openChannel not rejected") + } + + ocf, ok := err.(*OpenChannelError) + if !ok { + t.Errorf("got %#v want *OpenChannelError", err) + } else if ocf.Reason != 42 || ocf.Message != "message" { + t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") + } + + want := "ssh: rejected: unknown reason 42 (message)" + if err.Error() != want { + t.Errorf("got %q, want %q", err.Error(), want) + } +} + +func TestMuxChannelRequest(t *testing.T) { + client, server, mux := channelPair(t) + defer server.Close() + defer client.Close() + defer mux.Close() + + var received int + var wg sync.WaitGroup + wg.Add(1) + go func() { + for r := range server.incomingRequests { + received++ + r.Reply(r.Type == "yes", nil) + } + wg.Done() + }() + _, err := client.SendRequest("yes", false, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + ok, err := client.SendRequest("yes", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + + if !ok { + t.Errorf("SendRequest(yes): %v", ok) + + } + + ok, err = client.SendRequest("no", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + if ok { + t.Errorf("SendRequest(no): %v", ok) + + } + + client.Close() + wg.Wait() + + if received != 3 { + t.Errorf("got %d requests, want %d", received, 3) + } +} + +func TestMuxGlobalRequest(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + var seen bool + go func() { + for r := range serverMux.incomingRequests { + seen = seen || r.Type == "peek" + if r.WantReply { + err := r.Reply(r.Type == "yes", + append([]byte(r.Type), r.Payload...)) + if err != nil { + t.Errorf("AckRequest: %v", err) + } + } + } + }() + + _, _, err := clientMux.SendRequest("peek", false, nil) + if err != nil { + t.Errorf("SendRequest: %v", err) + } + + ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) + if !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { + t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if !seen { + t.Errorf("never saw 'peek' request") + } +} + +func TestMuxGlobalRequestUnblock(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + result := make(chan error, 1) + go func() { + _, _, err := clientMux.SendRequest("hello", true, nil) + result <- err + }() + + <-serverMux.incomingRequests + serverMux.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", io.EOF) + } +} + +func TestMuxChannelRequestUnblock(t *testing.T) { + a, b, connB := channelPair(t) + defer a.Close() + defer b.Close() + defer connB.Close() + + result := make(chan error, 1) + go func() { + _, err := a.SendRequest("hello", true, nil) + result <- err + }() + + <-b.incomingRequests + connB.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", err) + } +} + +func TestMuxCloseChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + defer r.Close() + defer w.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.Close(); err != nil { + t.Errorf("w.Close: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after Close", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxCloseWriteChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.CloseWrite(); err != nil { + t.Errorf("w.CloseWrite: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after CloseWrite", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxInvalidRecord(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], 29348723 /* invalid channel id */) + marshalUint32(packet[5:], 1) + packet[9] = 42 + + a.conn.writePacket(packet) + go a.SendRequest("hello", false, nil) + // 'a' wrote an invalid packet, so 'b' has exited. + req, ok := <-b.incomingRequests + if ok { + t.Errorf("got request %#v after receiving invalid packet", req) + } +} + +func TestZeroWindowAdjust(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + go func() { + io.WriteString(a, "hello") + // bogus adjust. + a.sendMessage(windowAdjustMsg{}) + io.WriteString(a, "world") + a.Close() + }() + + want := "helloworld" + c, _ := ioutil.ReadAll(b) + if string(c) != want { + t.Errorf("got %q want %q", c, want) + } +} + +func TestMuxMaxPacketSize(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + large := make([]byte, a.maxRemotePayload+1) + packet := make([]byte, 1+4+4+1+len(large)) + packet[0] = msgChannelData + marshalUint32(packet[1:], a.remoteId) + marshalUint32(packet[5:], uint32(len(large))) + packet[9] = 42 + + if err := a.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + + go a.SendRequest("hello", false, nil) + + _, ok := <-b.incomingRequests + if ok { + t.Errorf("connection still alive after receiving large packet.") + } +} + +// Don't ship code with debug=true. +func TestDebug(t *testing.T) { + if debugMux { + t.Error("mux debug switched on") + } + if debugHandshake { + t.Error("handshake debug switched on") + } +} diff --git a/vendor/golang.org/x/crypto/ssh/server.go b/vendor/golang.org/x/crypto/ssh/server.go new file mode 100644 index 0000000..37df1b3 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/server.go @@ -0,0 +1,488 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" +) + +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a +// user. Permissions, except for "source-address", must be enforced in +// the server application layer, after successful authentication. The +// Permissions are passed on in ServerConn so a server implementation +// can honor them. +type Permissions struct { + // Critical options restrict default permissions. Common + // restrictions are "source-address" and "force-command". If + // the server cannot enforce the restriction, or does not + // recognize it, the user should not authenticate. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Common extensions are + // "permit-agent-forwarding", "permit-X11-forwarding". Lack of + // support for an extension does not preclude authenticating a + // user. + Extensions map[string]string +} + +// ServerConfig holds server specific configuration data. +type ServerConfig struct { + // Config contains configuration shared between client and server. + Config + + hostKeys []Signer + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + NoClientAuth bool + + // PasswordCallback, if non-nil, is called when a user + // attempts to authenticate using a password. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback, if non-nil, is called when a client attempts public + // key authentication. It must return true if the given public key is + // valid for the given user. For example, see CertChecker.Authenticate. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback, if non-nil, is called when + // keyboard-interactive authentication is selected (RFC + // 4256). The client object's Challenge function should be + // used to query the user. The callback may offer multiple + // Challenge rounds. To avoid information leaks, the client + // should be presented a challenge even if the user is + // unknown. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // AuthLogCallback, if non-nil, is called to log all authentication + // attempts. + AuthLogCallback func(conn ConnMetadata, method string, err error) + + // ServerVersion is the version identification string to announce in + // the public handshake. + // If empty, a reasonable default is used. + // Note that RFC 4253 section 4.2 requires that this string start with + // "SSH-2.0-". + ServerVersion string +} + +// AddHostKey adds a private key as a host key. If an existing host +// key exists with the same algorithm, it is overwritten. Each server +// config must have at least one host key. +func (s *ServerConfig) AddHostKey(key Signer) { + for i, k := range s.hostKeys { + if k.PublicKey().Type() == key.PublicKey().Type() { + s.hostKeys[i] = key + return + } + } + + s.hostKeys = append(s.hostKeys, key) +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. +type cachedPubKey struct { + user string + pubKeyData []byte + result error + perms *Permissions +} + +const maxCachedPubKeys = 16 + +// pubKeyCache caches tests for public keys. Since SSH clients +// will query whether a public key is acceptable before attempting to +// authenticate with it, we end up with duplicate queries for public +// key validity. The cache only applies to a single ServerConn. +type pubKeyCache struct { + keys []cachedPubKey +} + +// get returns the result for a given user/algo/key tuple. +func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { + for _, k := range c.keys { + if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { + return k, true + } + } + return cachedPubKey{}, false +} + +// add adds the given tuple to the cache. +func (c *pubKeyCache) add(candidate cachedPubKey) { + if len(c.keys) < maxCachedPubKeys { + c.keys = append(c.keys, candidate) + } +} + +// ServerConn is an authenticated SSH connection, as seen from the +// server +type ServerConn struct { + Conn + + // If the succeeding authentication callback returned a + // non-nil Permissions pointer, it is stored here. + Permissions *Permissions +} + +// NewServerConn starts a new SSH server with c as the underlying +// transport. It starts with a handshake and, if the handshake is +// unsuccessful, it closes the connection and returns an error. The +// Request and NewChannel channels must be serviced, or the connection +// will hang. +func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + s := &connection{ + sshConn: sshConn{conn: c}, + } + perms, err := s.serverHandshake(&fullConf) + if err != nil { + c.Close() + return nil, nil, nil, err + } + return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil +} + +// signAndMarshal signs the data with the appropriate algorithm, +// and serializes the result in SSH wire format. +func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { + sig, err := k.Sign(rand, data) + if err != nil { + return nil, err + } + + return Marshal(sig), nil +} + +// handshake performs key exchange and user authentication. +func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { + if len(config.hostKeys) == 0 { + return nil, errors.New("ssh: server has no host keys") + } + + if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if config.ServerVersion != "" { + s.serverVersion = []byte(config.ServerVersion) + } else { + s.serverVersion = []byte(packageVersion) + } + var err error + s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) + if err != nil { + return nil, err + } + + tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) + s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) + + if err := s.transport.requestInitialKeyChange(); err != nil { + return nil, err + } + + // We just did the key change, so the session ID is established. + s.sessionID = s.transport.getSessionID() + + var packet []byte + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + + var serviceRequest serviceRequestMsg + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + if serviceRequest.Service != serviceUserAuth { + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err + } + + perms, err := s.serverAuthenticate(config) + if err != nil { + return nil, err + } + s.mux = newMux(s.transport) + return perms, err +} + +func isAcceptableAlgo(algo string) bool { + switch algo { + case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519, + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: + return true + } + return false +} + +func checkSourceAddress(addr net.Addr, sourceAddr string) error { + if addr == nil { + return errors.New("ssh: no address known for client, but source-address match required") + } + + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) + } + + if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { + if bytes.Equal(allowedIP, tcpAddr.IP) { + return nil + } + } else { + _, ipNet, err := net.ParseCIDR(sourceAddr) + if err != nil { + return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) + } + + if ipNet.Contains(tcpAddr.IP) { + return nil + } + } + + return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) +} + +func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { + var err error + var cache pubKeyCache + var perms *Permissions + +userAuthLoop: + for { + var userAuthReq userAuthRequestMsg + if packet, err := s.transport.readPacket(); err != nil { + return nil, err + } else if err = Unmarshal(packet, &userAuthReq); err != nil { + return nil, err + } + + if userAuthReq.Service != serviceSSH { + return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + s.user = userAuthReq.User + perms = nil + authErr := errors.New("no auth passed yet") + + switch userAuthReq.Method { + case "none": + if config.NoClientAuth { + authErr = nil + } + case "password": + if config.PasswordCallback == nil { + authErr = errors.New("ssh: password auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return nil, parseError(msgUserAuthRequest) + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + perms, authErr = config.PasswordCallback(s, password) + case "keyboard-interactive": + if config.KeyboardInteractiveCallback == nil { + authErr = errors.New("ssh: keyboard-interactive auth not configubred") + break + } + + prompter := &sshClientKeyboardInteractive{s} + perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) + case "publickey": + if config.PublicKeyCallback == nil { + authErr = errors.New("ssh: publickey auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return nil, parseError(msgUserAuthRequest) + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + algo := string(algoBytes) + if !isAcceptableAlgo(algo) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) + break + } + + pubKeyData, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + + pubKey, err := ParsePublicKey(pubKeyData) + if err != nil { + return nil, err + } + + candidate, ok := cache.get(s.user, pubKeyData) + if !ok { + candidate.user = s.user + candidate.pubKeyData = pubKeyData + candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) + if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { + candidate.result = checkSourceAddress( + s.RemoteAddr(), + candidate.perms.CriticalOptions[sourceAddressCriticalOption]) + } + cache.add(candidate) + } + + if isQuery { + // The client can query if the given public key + // would be okay. + if len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + if candidate.result == nil { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: pubKeyData, + } + if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { + return nil, err + } + continue userAuthLoop + } + authErr = candidate.result + } else { + sig, payload, ok := parseSignature(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + // Ensure the public key algo and signature algo + // are supported. Compare the private key + // algorithm name that corresponds to algo with + // sig.Format. This is usually the same, but + // for certs, the names differ. + if !isAcceptableAlgo(sig.Format) { + break + } + signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) + + if err := pubKey.Verify(signedData, sig); err != nil { + return nil, err + } + + authErr = candidate.result + perms = candidate.perms + } + default: + authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) + } + + if config.AuthLogCallback != nil { + config.AuthLogCallback(s, userAuthReq.Method, authErr) + } + + if authErr == nil { + break userAuthLoop + } + + var failureMsg userAuthFailureMsg + if config.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if config.PublicKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + if config.KeyboardInteractiveCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") + } + + if len(failureMsg.Methods) == 0 { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { + return nil, err + } + } + + if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { + return nil, err + } + return perms, nil +} + +// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by +// asking the client on the other side of a ServerConn. +type sshClientKeyboardInteractive struct { + *connection +} + +func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) != len(echos) { + return nil, errors.New("ssh: echos and questions must have equal length") + } + + var prompts []byte + for i := range questions { + prompts = appendString(prompts, questions[i]) + prompts = appendBool(prompts, echos[i]) + } + + if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ + Instruction: instruction, + NumPrompts: uint32(len(questions)), + Prompts: prompts, + })); err != nil { + return nil, err + } + + packet, err := c.transport.readPacket() + if err != nil { + return nil, err + } + if packet[0] != msgUserAuthInfoResponse { + return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) + } + packet = packet[1:] + + n, packet, ok := parseUint32(packet) + if !ok || int(n) != len(questions) { + return nil, parseError(msgUserAuthInfoResponse) + } + + for i := uint32(0); i < n; i++ { + ans, rest, ok := parseString(packet) + if !ok { + return nil, parseError(msgUserAuthInfoResponse) + } + + answers = append(answers, string(ans)) + packet = rest + } + if len(packet) != 0 { + return nil, errors.New("ssh: junk at end of message") + } + + return answers, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/session.go b/vendor/golang.org/x/crypto/ssh/session.go new file mode 100644 index 0000000..17e2aa8 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/session.go @@ -0,0 +1,627 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "sync" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +var signals = map[Signal]int{ + SIGABRT: 6, + SIGALRM: 14, + SIGFPE: 8, + SIGHUP: 1, + SIGILL: 4, + SIGINT: 2, + SIGKILL: 9, + SIGPIPE: 13, + SIGQUIT: 3, + SIGSEGV: 11, + SIGTERM: 15, +} + +type TerminalModes map[uint8]uint32 + +// POSIX terminal mode flags as listed in RFC 4254 Section 8. +const ( + tty_OP_END = 0 + VINTR = 1 + VQUIT = 2 + VERASE = 3 + VKILL = 4 + VEOF = 5 + VEOL = 6 + VEOL2 = 7 + VSTART = 8 + VSTOP = 9 + VSUSP = 10 + VDSUSP = 11 + VREPRINT = 12 + VWERASE = 13 + VLNEXT = 14 + VFLUSH = 15 + VSWTCH = 16 + VSTATUS = 17 + VDISCARD = 18 + IGNPAR = 30 + PARMRK = 31 + INPCK = 32 + ISTRIP = 33 + INLCR = 34 + IGNCR = 35 + ICRNL = 36 + IUCLC = 37 + IXON = 38 + IXANY = 39 + IXOFF = 40 + IMAXBEL = 41 + ISIG = 50 + ICANON = 51 + XCASE = 52 + ECHO = 53 + ECHOE = 54 + ECHOK = 55 + ECHONL = 56 + NOFLSH = 57 + TOSTOP = 58 + IEXTEN = 59 + ECHOCTL = 60 + ECHOKE = 61 + PENDIN = 62 + OPOST = 70 + OLCUC = 71 + ONLCR = 72 + OCRNL = 73 + ONOCR = 74 + ONLRET = 75 + CS7 = 90 + CS8 = 91 + PARENB = 92 + PARODD = 93 + TTY_OP_ISPEED = 128 + TTY_OP_OSPEED = 129 +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Stdin specifies the remote process's standard input. + // If Stdin is nil, the remote process reads from an empty + // bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr specify the remote process's standard + // output and error. + // + // If either is nil, Run connects the corresponding file + // descriptor to an instance of ioutil.Discard. There is a + // fixed amount of buffering that is shared for the two streams. + // If either blocks it may eventually cause the remote + // command to block. + Stdout io.Writer + Stderr io.Writer + + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. + copyFuncs []func() error + errors chan error // one send per copyFunc + + // true if pipe method is active + stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) +} + +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(name, value string) error { + msg := setenvRequest{ + Name: name, + Value: value, + } + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") + } + return err +} + +// RFC 4254 Section 6.2. +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { + var tm []byte + for k, v := range termmodes { + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) + } + tm = append(tm, tty_OP_END) + req := ptyRequestMsg{ + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), + } + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") + } + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated +func (s *Session) RequestSubsystem(subsystem string) error { + msg := subsystemRequestMsg{ + Subsystem: subsystem, + } + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") + } + return err +} + +// RFC 4254 Section 6.9. +type signalMsg struct { + Signal string +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(sig Signal) error { + msg := signalMsg{ + Signal: string(sig), + } + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err +} + +// RFC 4254 Section 6.5. +type execMsg struct { + Command string +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(cmd string) error { + if s.started { + return errors.New("ssh: session already started") + } + req := execMsg{ + Command: cmd, + } + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) + } + if err != nil { + return err + } + return s.start() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Run(cmd string) error { + err := s.Start(cmd) + if err != nil { + return err + } + return s.Wait() +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + var b bytes.Buffer + s.Stdout = &b + err := s.Run(cmd) + return b.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. +func (s *Session) CombinedOutput(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + var b singleWriter + s.Stdout = &b + s.Stderr = &b + err := s.Run(cmd) + return b.b.Bytes(), err +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell() error { + if s.started { + return errors.New("ssh: session already started") + } + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return errors.New("ssh: could not start shell") + } + if err != nil { + return err + } + return s.start() +} + +func (s *Session) start() error { + s.started = true + + type F func(*Session) + for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { + setupFd(s) + } + + s.errors = make(chan error, len(s.copyFuncs)) + for _, fn := range s.copyFuncs { + go func(fn func() error) { + s.errors <- fn() + }(fn) + } + return nil +} + +// Wait waits for the remote command to exit. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Wait() error { + if !s.started { + return errors.New("ssh: session not started") + } + waitErr := <-s.exitStatus + + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } + var copyError error + for _ = range s.copyFuncs { + if err := <-s.errors; err != nil && copyError == nil { + copyError = err + } + } + if waitErr != nil { + return waitErr + } + return copyError +} + +func (s *Session) wait(reqs <-chan *Request) error { + wm := Waitmsg{status: -1} + // Wait for msg channel to be closed before returning. + for msg := range reqs { + switch msg.Type { + case "exit-status": + wm.status = int(binary.BigEndian.Uint32(msg.Payload)) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err + } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang + default: + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + if wm.status == 0 { + return nil + } + if wm.status == -1 { + // exit-status was never sent from server + if wm.signal == "" { + // signal was not sent either. RFC 4254 + // section 6.10 recommends against this + // behavior, but it is allowed, so we let + // clients handle it. + return &ExitMissingError{} + } + wm.status = 128 + if _, ok := signals[Signal(wm.signal)]; ok { + wm.status += signals[Signal(wm.signal)] + } + } + + return &ExitError{wm} +} + +// ExitMissingError is returned if a session is torn down cleanly, but +// the server sends no confirmation of the exit status. +type ExitMissingError struct{} + +func (e *ExitMissingError) Error() string { + return "wait: remote command exited without exit status or exit signal" +} + +func (s *Session) stdin() { + if s.stdinpipe { + return + } + var stdin io.Reader + if s.Stdin == nil { + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { + err = err1 + } + return err + }) +} + +func (s *Session) stdout() { + if s.stdoutpipe { + return + } + if s.Stdout == nil { + s.Stdout = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stdout, s.ch) + return err + }) +} + +func (s *Session) stderr() { + if s.stderrpipe { + return + } + if s.Stderr == nil { + s.Stderr = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stderr, s.ch.Stderr()) + return err + }) +} + +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + +// StdinPipe returns a pipe that will be connected to the +// remote command's standard input when the command starts. +func (s *Session) StdinPipe() (io.WriteCloser, error) { + if s.Stdin != nil { + return nil, errors.New("ssh: Stdin already set") + } + if s.started { + return nil, errors.New("ssh: StdinPipe after process started") + } + s.stdinpipe = true + return &sessionStdin{s.ch, s.ch}, nil +} + +// StdoutPipe returns a pipe that will be connected to the +// remote command's standard output when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StdoutPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StdoutPipe() (io.Reader, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.started { + return nil, errors.New("ssh: StdoutPipe after process started") + } + s.stdoutpipe = true + return s.ch, nil +} + +// StderrPipe returns a pipe that will be connected to the +// remote command's standard error when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StderrPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StderrPipe() (io.Reader, error) { + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + if s.started { + return nil, errors.New("ssh: StderrPipe after process started") + } + s.stderrpipe = true + return s.ch.Stderr(), nil +} + +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, + } + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil +} + +// An ExitError reports unsuccessful completion of a remote command. +type ExitError struct { + Waitmsg +} + +func (e *ExitError) Error() string { + return e.Waitmsg.String() +} + +// Waitmsg stores the information about an exited remote command +// as reported by Wait. +type Waitmsg struct { + status int + signal string + msg string + lang string +} + +// ExitStatus returns the exit status of the remote command. +func (w Waitmsg) ExitStatus() int { + return w.status +} + +// Signal returns the exit signal of the remote command if +// it was terminated violently. +func (w Waitmsg) Signal() string { + return w.signal +} + +// Msg returns the exit message given by the remote command +func (w Waitmsg) Msg() string { + return w.msg +} + +// Lang returns the language tag. See RFC 3066 +func (w Waitmsg) Lang() string { + return w.lang +} + +func (w Waitmsg) String() string { + str := fmt.Sprintf("Process exited with status %v", w.status) + if w.signal != "" { + str += fmt.Sprintf(" from signal %v", w.signal) + } + if w.msg != "" { + str += fmt.Sprintf(". Reason was: %v", w.msg) + } + return str +} diff --git a/vendor/golang.org/x/crypto/ssh/session_test.go b/vendor/golang.org/x/crypto/ssh/session_test.go new file mode 100644 index 0000000..f35a378 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/session_test.go @@ -0,0 +1,770 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session tests. + +import ( + "bytes" + crypto_rand "crypto/rand" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "testing" + + "golang.org/x/crypto/ssh/terminal" +) + +type serverType func(Channel, <-chan *Request, *testing.T) + +// dial constructs a new test server and returns a *ClientConn. +func dial(handler serverType, t *testing.T) *Client { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + go func() { + defer c1.Close() + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(testSigners["rsa"]) + + _, chans, reqs, err := NewServerConn(c1, &conf) + if err != nil { + t.Fatalf("Unable to handshake: %v", err) + } + go DiscardRequests(reqs) + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(UnknownChannelType, "unknown channel type") + continue + } + + ch, inReqs, err := newCh.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + continue + } + go func() { + handler(ch, inReqs, t) + }() + } + }() + + config := &ClientConfig{ + User: "testuser", + } + + conn, chans, reqs, err := NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("unable to dial remote side: %v", err) + } + + return NewClient(conn, chans, reqs) +} + +// Test a simple string is returned to session.Stdout. +func TestSessionShell(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout := new(bytes.Buffer) + session.Stdout = stdout + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %s", err) + } + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + actual := stdout.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. + +// Test a simple string is returned via StdoutPipe. +func TestSessionStdoutPipe(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("Unable to request StdoutPipe(): %v", err) + } + var buf bytes.Buffer + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + done := make(chan bool, 1) + go func() { + if _, err := io.Copy(&buf, stdout); err != nil { + t.Errorf("Copy of stdout failed: %v", err) + } + done <- true + }() + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + <-done + actual := buf.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// Test that a simple string is returned via the Output helper, +// and that stderr is discarded. +func TestSessionOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.Output("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + w := "this-is-stdout." + g := string(buf) + if g != w { + t.Error("Remote command did not return expected string:") + t.Logf("want %q", w) + t.Logf("got %q", g) + } +} + +// Test that both stdout and stderr are returned +// via the CombinedOutput helper. +func TestSessionCombinedOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + const stdout = "this-is-stdout." + const stderr = "this-is-stderr." + g := string(buf) + if g != stdout+stderr && g != stderr+stdout { + t.Error("Remote command did not return expected string:") + t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) + t.Logf("got %q", g) + } +} + +// Test non-0 exit status is returned correctly. +func TestExitStatusNonZero(t *testing.T) { + conn := dial(exitStatusNonZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) + } +} + +// Test 0 exit status is returned correctly. +func TestExitStatusZero(t *testing.T) { + conn := dial(exitStatusZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +// Test exit signal and status are both returned correctly. +func TestExitSignalAndStatus(t *testing.T) { + conn := dial(exitSignalAndStatusHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestKnownExitSignalOnly(t *testing.T) { + conn := dial(exitSignalHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 143 { + t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestUnknownExitSignal(t *testing.T) { + conn := dial(exitSignalUnknownHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "SYS" || e.ExitStatus() != 128 { + t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +func TestExitWithoutStatusOrSignal(t *testing.T) { + conn := dial(exitWithoutSignalOrStatus, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + if _, ok := err.(*ExitMissingError); !ok { + t.Fatalf("got %T want *ExitMissingError", err) + } +} + +// windowTestBytes is the number of bytes that we'll send to the SSH server. +const windowTestBytes = 16000 * 200 + +// TestServerWindow writes random data to the server. The server is expected to echo +// the same data back, which is compared against the original. +func TestServerWindow(t *testing.T) { + origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) + origBytes := origBuf.Bytes() + + conn := dial(echoHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + result := make(chan []byte) + + go func() { + defer close(result) + echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + serverStdout, err := session.StdoutPipe() + if err != nil { + t.Errorf("StdoutPipe failed: %v", err) + return + } + n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) + if err != nil && err != io.EOF { + t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) + } + result <- echoedBuf.Bytes() + }() + + serverStdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) + if err != nil { + t.Fatalf("failed to copy origBuf to serverStdin: %v", err) + } + if written != windowTestBytes { + t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes) + } + + echoedBytes := <-result + + if !bytes.Equal(origBytes, echoedBytes) { + t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) + } +} + +// Verify the client can handle a keepalive packet from the server. +func TestClientHandlesKeepalives(t *testing.T) { + conn := dial(channelKeepaliveSender, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got: %v", err) + } +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Errmsg string + Lang string +} + +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) + } +} + +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(0, ch, t) +} + +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) +} + +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) + sendSignal("TERM", ch, t) +} + +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("TERM", ch, t) +} + +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("SYS", ch, t) +} + +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) +} + +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "golang") + readLine(shell, t) + sendStatus(0, ch, t) +} + +// Ignores the command, writes fixed strings to stderr and stdout. +// Strings are "this-is-stdout." and "this-is-stderr.". +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + _, err := ch.Read(nil) + + req, ok := <-in + if !ok { + t.Fatalf("error: expected channel request, got: %#v", err) + return + } + + // ignore request, always send some text + req.Reply(true, nil) + + _, err = io.WriteString(ch, "this-is-stdout.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + sendStatus(0, ch, t) +} + +func readLine(shell *terminal.Terminal, t *testing.T) { + if _, err := shell.ReadLine(); err != nil && err != io.EOF { + t.Errorf("unable to read line: %v", err) + } +} + +func sendStatus(status uint32, ch Channel, t *testing.T) { + msg := exitStatusMsg{ + Status: status, + } + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { + t.Errorf("unable to send status: %v", err) + } +} + +func sendSignal(signal string, ch Channel, t *testing.T) { + sig := exitSignalMsg{ + Signal: signal, + CoreDumped: false, + Errmsg: "Process terminated", + Lang: "en-GB-oed", + } + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { + t.Errorf("unable to send signal: %v", err) + } +} + +func discardHandler(ch Channel, t *testing.T) { + defer ch.Close() + io.Copy(ioutil.Discard, ch) +} + +func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { + t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) + } +} + +// copyNRandomly copies n bytes from src to dst. It uses a variable, and random, +// buffer size to exercise more code paths. +func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { + var ( + buf = make([]byte, 32*1024) + written int + remaining = n + ) + for remaining > 0 { + l := rand.Intn(1 << 15) + if remaining < l { + l = remaining + } + nr, er := src.Read(buf[:l]) + nw, ew := dst.Write(buf[:nr]) + remaining -= nw + written += nw + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + if er != nil && er != io.EOF { + return written, er + } + } + return written, nil +} + +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { + t.Errorf("unable to send channel keepalive request: %v", err) + } + sendStatus(0, ch, t) +} + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := ioutil.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + data, err := ioutil.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} + +func TestSessionID(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverID := make(chan []byte, 1) + clientID := make(chan []byte, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + User: "user", + } + + go func() { + conn, chans, reqs, err := NewServerConn(c1, serverConf) + if err != nil { + t.Fatalf("server handshake: %v", err) + } + serverID <- conn.SessionID() + go DiscardRequests(reqs) + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + go func() { + conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + clientID <- conn.SessionID() + go DiscardRequests(reqs) + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + s := <-serverID + c := <-clientID + if bytes.Compare(s, c) != 0 { + t.Errorf("server session ID (%x) != client session ID (%x)", s, c) + } else if len(s) == 0 { + t.Errorf("client and server SessionID were empty.") + } +} + +type noReadConn struct { + readSeen bool + net.Conn +} + +func (c *noReadConn) Close() error { + return nil +} + +func (c *noReadConn) Read(b []byte) (int, error) { + c.readSeen = true + return 0, errors.New("noReadConn error") +} + +func TestInvalidServerConfiguration(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serveConn := noReadConn{Conn: c1} + serverConf := &ServerConfig{} + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") + } + + serverConf.AddHostKey(testSigners["ecdsa"]) + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") + } +} + +func TestHostKeyAlgorithms(t *testing.T) { + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.AddHostKey(testSigners["ecdsa"]) + + connect := func(clientConf *ClientConfig, want string) { + var alg string + clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { + alg = key.Type() + return nil + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewServerConn(c1, serverConf) + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + if alg != want { + t.Errorf("selected key algorithm %s, want %s", alg, want) + } + } + + // By default, we get the preferred algorithm, which is ECDSA 256. + + clientConf := &ClientConfig{} + connect(clientConf, KeyAlgoECDSA256) + + // Client asks for RSA explicitly. + clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} + connect(clientConf, KeyAlgoRSA) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewServerConn(c1, serverConf) + clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with unknown hostkey algorithm") + } +} diff --git a/vendor/golang.org/x/crypto/ssh/tcpip.go b/vendor/golang.org/x/crypto/ssh/tcpip.go new file mode 100644 index 0000000..6151241 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/tcpip.go @@ -0,0 +1,407 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" +) + +// Listen requests the remote peer open a listening socket on +// addr. Incoming connections will be available by calling Accept on +// the returned net.Listener. The listener must be serviced, or the +// SSH connection may hang. +func (c *Client) Listen(n, addr string) (net.Listener, error) { + laddr, err := net.ResolveTCPAddr(n, addr) + if err != nil { + return nil, err + } + return c.ListenTCP(laddr) +} + +// Automatic port allocation is broken with OpenSSH before 6.0. See +// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In +// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0, +// rather than the actual port number. This means you can never open +// two different listeners with auto allocated ports. We work around +// this by trying explicit ports until we succeed. + +const openSSHPrefix = "OpenSSH_" + +var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) + +// isBrokenOpenSSHVersion returns true if the given version string +// specifies a version of OpenSSH that is known to have a bug in port +// forwarding. +func isBrokenOpenSSHVersion(versionStr string) bool { + i := strings.Index(versionStr, openSSHPrefix) + if i < 0 { + return false + } + i += len(openSSHPrefix) + j := i + for ; j < len(versionStr); j++ { + if versionStr[j] < '0' || versionStr[j] > '9' { + break + } + } + version, _ := strconv.Atoi(versionStr[i:j]) + return version < 6 +} + +// autoPortListenWorkaround simulates automatic port allocation by +// trying random ports repeatedly. +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { + var sshListener net.Listener + var err error + const tries = 10 + for i := 0; i < tries; i++ { + addr := *laddr + addr.Port = 1024 + portRandomizer.Intn(60000) + sshListener, err = c.ListenTCP(&addr) + if err == nil { + laddr.Port = addr.Port + return sshListener, err + } + } + return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) +} + +// RFC 4254 7.1 +type channelForwardMsg struct { + addr string + rport uint32 +} + +// ListenTCP requests the remote peer open a listening socket +// on laddr. Incoming connections will be available by calling +// Accept on the returned net.Listener. +func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { + return c.autoPortListenWorkaround(laddr) + } + + m := channelForwardMsg{ + laddr.IP.String(), + uint32(laddr.Port), + } + // send message + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + // If the original port was 0, then the remote side will + // supply a real port number in the response. + if laddr.Port == 0 { + var p struct { + Port uint32 + } + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) + } + + // Register this forward, using the port number we obtained. + ch := c.forwards.add(*laddr) + + return &tcpListener{laddr, c, ch}, nil +} + +// forwardList stores a mapping between remote +// forward requests and the tcpListeners. +type forwardList struct { + sync.Mutex + entries []forwardEntry +} + +// forwardEntry represents an established mapping of a laddr on a +// remote ssh server to a channel connected to a tcpListener. +type forwardEntry struct { + laddr net.TCPAddr + c chan forward +} + +// forward represents an incoming forwarded tcpip connection. The +// arguments to add/remove/lookup should be address as specified in +// the original forward-request. +type forward struct { + newCh NewChannel // the ssh client channel underlying this forward + raddr *net.TCPAddr // the raddr of the incoming connection +} + +func (l *forwardList) add(addr net.TCPAddr) chan forward { + l.Lock() + defer l.Unlock() + f := forwardEntry{ + addr, + make(chan forward, 1), + } + l.entries = append(l.entries, f) + return f.c +} + +// See RFC 4254, section 7.2 +type forwardedTCPPayload struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. +func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { + if port == 0 || port > 65535 { + return nil, fmt.Errorf("ssh: port number out of range: %d", port) + } + ip := net.ParseIP(string(addr)) + if ip == nil { + return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) + } + return &net.TCPAddr{IP: ip, Port: int(port)}, nil +} + +func (l *forwardList) handleChannels(in <-chan NewChannel) { + for ch := range in { + var payload forwardedTCPPayload + if err := Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) + continue + } + + // RFC 4254 section 7.2 specifies that incoming + // addresses should list the address, in string + // format. It is implied that this should be an IP + // address, as it would be impossible to connect to it + // otherwise. + laddr, err := parseTCPAddr(payload.Addr, payload.Port) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + + if ok := l.forward(*laddr, *raddr, ch); !ok { + // Section 7.2, implementations MUST reject spurious incoming + // connections. + ch.Reject(Prohibited, "no forward for address") + continue + } + } +} + +// remove removes the forward entry, and the channel feeding its +// listener. +func (l *forwardList) remove(addr net.TCPAddr) { + l.Lock() + defer l.Unlock() + for i, f := range l.entries { + if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { + l.entries = append(l.entries[:i], l.entries[i+1:]...) + close(f.c) + return + } + } +} + +// closeAll closes and clears all forwards. +func (l *forwardList) closeAll() { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + close(f.c) + } + l.entries = nil +} + +func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { + f.c <- forward{ch, &raddr} + return true + } + } + return false +} + +type tcpListener struct { + laddr *net.TCPAddr + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *tcpListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &tcpChanConn{ + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, + }, nil +} + +// Close closes the listener. +func (l *tcpListener) Close() error { + m := channelForwardMsg{ + l.laddr.IP.String(), + uint32(l.laddr.Port), + } + + // this also closes the listener. + l.conn.forwards.remove(*l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.laddr +} + +// Dial initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) Dial(n, addr string) (net.Conn, error) { + // Parse the address into host and numeric port. + host, portString, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + // Use a zero address for local and remote address. + zeroAddr := &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) + if err != nil { + return nil, err + } + return &tcpChanConn{ + Channel: ch, + laddr: zeroAddr, + raddr: zeroAddr, + }, nil +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { + if laddr == nil { + laddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + if err != nil { + return nil, err + } + return &tcpChanConn{ + Channel: ch, + laddr: laddr, + raddr: raddr, + }, nil +} + +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + raddr string + rport uint32 + laddr string + lport uint32 +} + +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), + } + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type tcpChan struct { + Channel // the backing channel +} + +// tcpChanConn fulfills the net.Conn interface without +// the tcpChan having to hold laddr or raddr directly. +type tcpChanConn struct { + Channel + laddr, raddr net.Addr +} + +// LocalAddr returns the local network address. +func (t *tcpChanConn) LocalAddr() net.Addr { + return t.laddr +} + +// RemoteAddr returns the remote network address. +func (t *tcpChanConn) RemoteAddr() net.Addr { + return t.raddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. +func (t *tcpChanConn) SetDeadline(deadline time.Time) error { + if err := t.SetReadDeadline(deadline); err != nil { + return err + } + return t.SetWriteDeadline(deadline) +} + +// SetReadDeadline sets the read deadline. +// A zero value for t means Read will not time out. +// After the deadline, the error from Read will implement net.Error +// with Timeout() == true. +func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} + +// SetWriteDeadline exists to satisfy the net.Conn interface +// but is not implemented by this type. It always returns an error. +func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} diff --git a/vendor/golang.org/x/crypto/ssh/tcpip_test.go b/vendor/golang.org/x/crypto/ssh/tcpip_test.go new file mode 100644 index 0000000..f1265cb --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/tcpip_test.go @@ -0,0 +1,20 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "testing" +) + +func TestAutoPortListenBroken(t *testing.T) { + broken := "SSH-2.0-OpenSSH_5.9hh11" + works := "SSH-2.0-OpenSSH_6.1" + if !isBrokenOpenSSHVersion(broken) { + t.Errorf("version %q not marked as broken", broken) + } + if isBrokenOpenSSHVersion(works) { + t.Errorf("version %q marked as broken", works) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/terminal/terminal.go b/vendor/golang.org/x/crypto/ssh/terminal/terminal.go new file mode 100644 index 0000000..5ea89a2 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/terminal.go @@ -0,0 +1,924 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import ( + "bytes" + "io" + "sync" + "unicode/utf8" +) + +// EscapeCodes contains escape sequences that can be written to the terminal in +// order to achieve different styles of text. +type EscapeCodes struct { + // Foreground colors + Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte + + // Reset all attributes + Reset []byte +} + +var vt100EscapeCodes = EscapeCodes{ + Black: []byte{keyEscape, '[', '3', '0', 'm'}, + Red: []byte{keyEscape, '[', '3', '1', 'm'}, + Green: []byte{keyEscape, '[', '3', '2', 'm'}, + Yellow: []byte{keyEscape, '[', '3', '3', 'm'}, + Blue: []byte{keyEscape, '[', '3', '4', 'm'}, + Magenta: []byte{keyEscape, '[', '3', '5', 'm'}, + Cyan: []byte{keyEscape, '[', '3', '6', 'm'}, + White: []byte{keyEscape, '[', '3', '7', 'm'}, + + Reset: []byte{keyEscape, '[', '0', 'm'}, +} + +// Terminal contains the state for running a VT100 terminal that is capable of +// reading lines of input. +type Terminal struct { + // AutoCompleteCallback, if non-null, is called for each keypress with + // the full input line and the current position of the cursor (in + // bytes, as an index into |line|). If it returns ok=false, the key + // press is processed normally. Otherwise it returns a replacement line + // and the new cursor position. + AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) + + // Escape contains a pointer to the escape codes for this terminal. + // It's always a valid pointer, although the escape codes themselves + // may be empty if the terminal doesn't support them. + Escape *EscapeCodes + + // lock protects the terminal and the state in this object from + // concurrent processing of a key press and a Write() call. + lock sync.Mutex + + c io.ReadWriter + prompt []rune + + // line is the current line being entered. + line []rune + // pos is the logical position of the cursor in line + pos int + // echo is true if local echo is enabled + echo bool + // pasteActive is true iff there is a bracketed paste operation in + // progress. + pasteActive bool + + // cursorX contains the current X value of the cursor where the left + // edge is 0. cursorY contains the row number where the first row of + // the current line is 0. + cursorX, cursorY int + // maxLine is the greatest value of cursorY so far. + maxLine int + + termWidth, termHeight int + + // outBuf contains the terminal data to be sent. + outBuf []byte + // remainder contains the remainder of any partial key sequences after + // a read. It aliases into inBuf. + remainder []byte + inBuf [256]byte + + // history contains previously entered commands so that they can be + // accessed with the up and down keys. + history stRingBuffer + // historyIndex stores the currently accessed history entry, where zero + // means the immediately previous entry. + historyIndex int + // When navigating up and down the history it's possible to return to + // the incomplete, initial line. That value is stored in + // historyPending. + historyPending string +} + +// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is +// a local terminal, that terminal must first have been put into raw mode. +// prompt is a string that is written at the start of each input line (i.e. +// "> "). +func NewTerminal(c io.ReadWriter, prompt string) *Terminal { + return &Terminal{ + Escape: &vt100EscapeCodes, + c: c, + prompt: []rune(prompt), + termWidth: 80, + termHeight: 24, + echo: true, + historyIndex: -1, + } +} + +const ( + keyCtrlD = 4 + keyCtrlU = 21 + keyEnter = '\r' + keyEscape = 27 + keyBackspace = 127 + keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota + keyUp + keyDown + keyLeft + keyRight + keyAltLeft + keyAltRight + keyHome + keyEnd + keyDeleteWord + keyDeleteLine + keyClearScreen + keyPasteStart + keyPasteEnd +) + +var ( + crlf = []byte{'\r', '\n'} + pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} + pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} +) + +// bytesToKey tries to parse a key sequence from b. If successful, it returns +// the key and the remainder of the input. Otherwise it returns utf8.RuneError. +func bytesToKey(b []byte, pasteActive bool) (rune, []byte) { + if len(b) == 0 { + return utf8.RuneError, nil + } + + if !pasteActive { + switch b[0] { + case 1: // ^A + return keyHome, b[1:] + case 5: // ^E + return keyEnd, b[1:] + case 8: // ^H + return keyBackspace, b[1:] + case 11: // ^K + return keyDeleteLine, b[1:] + case 12: // ^L + return keyClearScreen, b[1:] + case 23: // ^W + return keyDeleteWord, b[1:] + } + } + + if b[0] != keyEscape { + if !utf8.FullRune(b) { + return utf8.RuneError, b + } + r, l := utf8.DecodeRune(b) + return r, b[l:] + } + + if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { + switch b[2] { + case 'A': + return keyUp, b[3:] + case 'B': + return keyDown, b[3:] + case 'C': + return keyRight, b[3:] + case 'D': + return keyLeft, b[3:] + case 'H': + return keyHome, b[3:] + case 'F': + return keyEnd, b[3:] + } + } + + if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { + switch b[5] { + case 'C': + return keyAltRight, b[6:] + case 'D': + return keyAltLeft, b[6:] + } + } + + if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) { + return keyPasteStart, b[6:] + } + + if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) { + return keyPasteEnd, b[6:] + } + + // If we get here then we have a key that we don't recognise, or a + // partial sequence. It's not clear how one should find the end of a + // sequence without knowing them all, but it seems that [a-zA-Z~] only + // appears at the end of a sequence. + for i, c := range b[0:] { + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' { + return keyUnknown, b[i+1:] + } + } + + return utf8.RuneError, b +} + +// queue appends data to the end of t.outBuf +func (t *Terminal) queue(data []rune) { + t.outBuf = append(t.outBuf, []byte(string(data))...) +} + +var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'} +var space = []rune{' '} + +func isPrintable(key rune) bool { + isInSurrogateArea := key >= 0xd800 && key <= 0xdbff + return key >= 32 && !isInSurrogateArea +} + +// moveCursorToPos appends data to t.outBuf which will move the cursor to the +// given, logical position in the text. +func (t *Terminal) moveCursorToPos(pos int) { + if !t.echo { + return + } + + x := visualLength(t.prompt) + pos + y := x / t.termWidth + x = x % t.termWidth + + up := 0 + if y < t.cursorY { + up = t.cursorY - y + } + + down := 0 + if y > t.cursorY { + down = y - t.cursorY + } + + left := 0 + if x < t.cursorX { + left = t.cursorX - x + } + + right := 0 + if x > t.cursorX { + right = x - t.cursorX + } + + t.cursorX = x + t.cursorY = y + t.move(up, down, left, right) +} + +func (t *Terminal) move(up, down, left, right int) { + movement := make([]rune, 3*(up+down+left+right)) + m := movement + for i := 0; i < up; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'A' + m = m[3:] + } + for i := 0; i < down; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'B' + m = m[3:] + } + for i := 0; i < left; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'D' + m = m[3:] + } + for i := 0; i < right; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'C' + m = m[3:] + } + + t.queue(movement) +} + +func (t *Terminal) clearLineToRight() { + op := []rune{keyEscape, '[', 'K'} + t.queue(op) +} + +const maxLineLength = 4096 + +func (t *Terminal) setLine(newLine []rune, newPos int) { + if t.echo { + t.moveCursorToPos(0) + t.writeLine(newLine) + for i := len(newLine); i < len(t.line); i++ { + t.writeLine(space) + } + t.moveCursorToPos(newPos) + } + t.line = newLine + t.pos = newPos +} + +func (t *Terminal) advanceCursor(places int) { + t.cursorX += places + t.cursorY += t.cursorX / t.termWidth + if t.cursorY > t.maxLine { + t.maxLine = t.cursorY + } + t.cursorX = t.cursorX % t.termWidth + + if places > 0 && t.cursorX == 0 { + // Normally terminals will advance the current position + // when writing a character. But that doesn't happen + // for the last character in a line. However, when + // writing a character (except a new line) that causes + // a line wrap, the position will be advanced two + // places. + // + // So, if we are stopping at the end of a line, we + // need to write a newline so that our cursor can be + // advanced to the next line. + t.outBuf = append(t.outBuf, '\r', '\n') + } +} + +func (t *Terminal) eraseNPreviousChars(n int) { + if n == 0 { + return + } + + if t.pos < n { + n = t.pos + } + t.pos -= n + t.moveCursorToPos(t.pos) + + copy(t.line[t.pos:], t.line[n+t.pos:]) + t.line = t.line[:len(t.line)-n] + if t.echo { + t.writeLine(t.line[t.pos:]) + for i := 0; i < n; i++ { + t.queue(space) + } + t.advanceCursor(n) + t.moveCursorToPos(t.pos) + } +} + +// countToLeftWord returns then number of characters from the cursor to the +// start of the previous word. +func (t *Terminal) countToLeftWord() int { + if t.pos == 0 { + return 0 + } + + pos := t.pos - 1 + for pos > 0 { + if t.line[pos] != ' ' { + break + } + pos-- + } + for pos > 0 { + if t.line[pos] == ' ' { + pos++ + break + } + pos-- + } + + return t.pos - pos +} + +// countToRightWord returns then number of characters from the cursor to the +// start of the next word. +func (t *Terminal) countToRightWord() int { + pos := t.pos + for pos < len(t.line) { + if t.line[pos] == ' ' { + break + } + pos++ + } + for pos < len(t.line) { + if t.line[pos] != ' ' { + break + } + pos++ + } + return pos - t.pos +} + +// visualLength returns the number of visible glyphs in s. +func visualLength(runes []rune) int { + inEscapeSeq := false + length := 0 + + for _, r := range runes { + switch { + case inEscapeSeq: + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + inEscapeSeq = false + } + case r == '\x1b': + inEscapeSeq = true + default: + length++ + } + } + + return length +} + +// handleKey processes the given key and, optionally, returns a line of text +// that the user has entered. +func (t *Terminal) handleKey(key rune) (line string, ok bool) { + if t.pasteActive && key != keyEnter { + t.addKeyToLine(key) + return + } + + switch key { + case keyBackspace: + if t.pos == 0 { + return + } + t.eraseNPreviousChars(1) + case keyAltLeft: + // move left by a word. + t.pos -= t.countToLeftWord() + t.moveCursorToPos(t.pos) + case keyAltRight: + // move right by a word. + t.pos += t.countToRightWord() + t.moveCursorToPos(t.pos) + case keyLeft: + if t.pos == 0 { + return + } + t.pos-- + t.moveCursorToPos(t.pos) + case keyRight: + if t.pos == len(t.line) { + return + } + t.pos++ + t.moveCursorToPos(t.pos) + case keyHome: + if t.pos == 0 { + return + } + t.pos = 0 + t.moveCursorToPos(t.pos) + case keyEnd: + if t.pos == len(t.line) { + return + } + t.pos = len(t.line) + t.moveCursorToPos(t.pos) + case keyUp: + entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1) + if !ok { + return "", false + } + if t.historyIndex == -1 { + t.historyPending = string(t.line) + } + t.historyIndex++ + runes := []rune(entry) + t.setLine(runes, len(runes)) + case keyDown: + switch t.historyIndex { + case -1: + return + case 0: + runes := []rune(t.historyPending) + t.setLine(runes, len(runes)) + t.historyIndex-- + default: + entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1) + if ok { + t.historyIndex-- + runes := []rune(entry) + t.setLine(runes, len(runes)) + } + } + case keyEnter: + t.moveCursorToPos(len(t.line)) + t.queue([]rune("\r\n")) + line = string(t.line) + ok = true + t.line = t.line[:0] + t.pos = 0 + t.cursorX = 0 + t.cursorY = 0 + t.maxLine = 0 + case keyDeleteWord: + // Delete zero or more spaces and then one or more characters. + t.eraseNPreviousChars(t.countToLeftWord()) + case keyDeleteLine: + // Delete everything from the current cursor position to the + // end of line. + for i := t.pos; i < len(t.line); i++ { + t.queue(space) + t.advanceCursor(1) + } + t.line = t.line[:t.pos] + t.moveCursorToPos(t.pos) + case keyCtrlD: + // Erase the character under the current position. + // The EOF case when the line is empty is handled in + // readLine(). + if t.pos < len(t.line) { + t.pos++ + t.eraseNPreviousChars(1) + } + case keyCtrlU: + t.eraseNPreviousChars(t.pos) + case keyClearScreen: + // Erases the screen and moves the cursor to the home position. + t.queue([]rune("\x1b[2J\x1b[H")) + t.queue(t.prompt) + t.cursorX, t.cursorY = 0, 0 + t.advanceCursor(visualLength(t.prompt)) + t.setLine(t.line, t.pos) + default: + if t.AutoCompleteCallback != nil { + prefix := string(t.line[:t.pos]) + suffix := string(t.line[t.pos:]) + + t.lock.Unlock() + newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key) + t.lock.Lock() + + if completeOk { + t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos])) + return + } + } + if !isPrintable(key) { + return + } + if len(t.line) == maxLineLength { + return + } + t.addKeyToLine(key) + } + return +} + +// addKeyToLine inserts the given key at the current position in the current +// line. +func (t *Terminal) addKeyToLine(key rune) { + if len(t.line) == cap(t.line) { + newLine := make([]rune, len(t.line), 2*(1+len(t.line))) + copy(newLine, t.line) + t.line = newLine + } + t.line = t.line[:len(t.line)+1] + copy(t.line[t.pos+1:], t.line[t.pos:]) + t.line[t.pos] = key + if t.echo { + t.writeLine(t.line[t.pos:]) + } + t.pos++ + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) writeLine(line []rune) { + for len(line) != 0 { + remainingOnLine := t.termWidth - t.cursorX + todo := len(line) + if todo > remainingOnLine { + todo = remainingOnLine + } + t.queue(line[:todo]) + t.advanceCursor(visualLength(line[:todo])) + line = line[todo:] + } +} + +// writeWithCRLF writes buf to w but replaces all occurances of \n with \r\n. +func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) { + for len(buf) > 0 { + i := bytes.IndexByte(buf, '\n') + todo := len(buf) + if i >= 0 { + todo = i + } + + var nn int + nn, err = w.Write(buf[:todo]) + n += nn + if err != nil { + return n, err + } + buf = buf[todo:] + + if i >= 0 { + if _, err = w.Write(crlf); err != nil { + return n, err + } + n += 1 + buf = buf[1:] + } + } + + return n, nil +} + +func (t *Terminal) Write(buf []byte) (n int, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.cursorX == 0 && t.cursorY == 0 { + // This is the easy case: there's nothing on the screen that we + // have to move out of the way. + return writeWithCRLF(t.c, buf) + } + + // We have a prompt and possibly user input on the screen. We + // have to clear it first. + t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */) + t.cursorX = 0 + t.clearLineToRight() + + for t.cursorY > 0 { + t.move(1 /* up */, 0, 0, 0) + t.cursorY-- + t.clearLineToRight() + } + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + + if n, err = writeWithCRLF(t.c, buf); err != nil { + return + } + + t.writeLine(t.prompt) + if t.echo { + t.writeLine(t.line) + } + + t.moveCursorToPos(t.pos) + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + return +} + +// ReadPassword temporarily changes the prompt and reads a password, without +// echo, from the terminal. +func (t *Terminal) ReadPassword(prompt string) (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + oldPrompt := t.prompt + t.prompt = []rune(prompt) + t.echo = false + + line, err = t.readLine() + + t.prompt = oldPrompt + t.echo = true + + return +} + +// ReadLine returns a line of input from the terminal. +func (t *Terminal) ReadLine() (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + return t.readLine() +} + +func (t *Terminal) readLine() (line string, err error) { + // t.lock must be held at this point + + if t.cursorX == 0 && t.cursorY == 0 { + t.writeLine(t.prompt) + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + } + + lineIsPasted := t.pasteActive + + for { + rest := t.remainder + lineOk := false + for !lineOk { + var key rune + key, rest = bytesToKey(rest, t.pasteActive) + if key == utf8.RuneError { + break + } + if !t.pasteActive { + if key == keyCtrlD { + if len(t.line) == 0 { + return "", io.EOF + } + } + if key == keyPasteStart { + t.pasteActive = true + if len(t.line) == 0 { + lineIsPasted = true + } + continue + } + } else if key == keyPasteEnd { + t.pasteActive = false + continue + } + if !t.pasteActive { + lineIsPasted = false + } + line, lineOk = t.handleKey(key) + } + if len(rest) > 0 { + n := copy(t.inBuf[:], rest) + t.remainder = t.inBuf[:n] + } else { + t.remainder = nil + } + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + if lineOk { + if t.echo { + t.historyIndex = -1 + t.history.Add(line) + } + if lineIsPasted { + err = ErrPasteIndicator + } + return + } + + // t.remainder is a slice at the beginning of t.inBuf + // containing a partial key sequence + readBuf := t.inBuf[len(t.remainder):] + var n int + + t.lock.Unlock() + n, err = t.c.Read(readBuf) + t.lock.Lock() + + if err != nil { + return + } + + t.remainder = t.inBuf[:n+len(t.remainder)] + } + + panic("unreachable") // for Go 1.0. +} + +// SetPrompt sets the prompt to be used when reading subsequent lines. +func (t *Terminal) SetPrompt(prompt string) { + t.lock.Lock() + defer t.lock.Unlock() + + t.prompt = []rune(prompt) +} + +func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) { + // Move cursor to column zero at the start of the line. + t.move(t.cursorY, 0, t.cursorX, 0) + t.cursorX, t.cursorY = 0, 0 + t.clearLineToRight() + for t.cursorY < numPrevLines { + // Move down a line + t.move(0, 1, 0, 0) + t.cursorY++ + t.clearLineToRight() + } + // Move back to beginning. + t.move(t.cursorY, 0, 0, 0) + t.cursorX, t.cursorY = 0, 0 + + t.queue(t.prompt) + t.advanceCursor(visualLength(t.prompt)) + t.writeLine(t.line) + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) SetSize(width, height int) error { + t.lock.Lock() + defer t.lock.Unlock() + + if width == 0 { + width = 1 + } + + oldWidth := t.termWidth + t.termWidth, t.termHeight = width, height + + switch { + case width == oldWidth: + // If the width didn't change then nothing else needs to be + // done. + return nil + case len(t.line) == 0 && t.cursorX == 0 && t.cursorY == 0: + // If there is nothing on current line and no prompt printed, + // just do nothing + return nil + case width < oldWidth: + // Some terminals (e.g. xterm) will truncate lines that were + // too long when shinking. Others, (e.g. gnome-terminal) will + // attempt to wrap them. For the former, repainting t.maxLine + // works great, but that behaviour goes badly wrong in the case + // of the latter because they have doubled every full line. + + // We assume that we are working on a terminal that wraps lines + // and adjust the cursor position based on every previous line + // wrapping and turning into two. This causes the prompt on + // xterms to move upwards, which isn't great, but it avoids a + // huge mess with gnome-terminal. + if t.cursorX >= t.termWidth { + t.cursorX = t.termWidth - 1 + } + t.cursorY *= 2 + t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2) + case width > oldWidth: + // If the terminal expands then our position calculations will + // be wrong in the future because we think the cursor is + // |t.pos| chars into the string, but there will be a gap at + // the end of any wrapped line. + // + // But the position will actually be correct until we move, so + // we can move back to the beginning and repaint everything. + t.clearAndRepaintLinePlusNPrevious(t.maxLine) + } + + _, err := t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + return err +} + +type pasteIndicatorError struct{} + +func (pasteIndicatorError) Error() string { + return "terminal: ErrPasteIndicator not correctly handled" +} + +// ErrPasteIndicator may be returned from ReadLine as the error, in addition +// to valid line data. It indicates that bracketed paste mode is enabled and +// that the returned line consists only of pasted data. Programs may wish to +// interpret pasted data more literally than typed data. +var ErrPasteIndicator = pasteIndicatorError{} + +// SetBracketedPasteMode requests that the terminal bracket paste operations +// with markers. Not all terminals support this but, if it is supported, then +// enabling this mode will stop any autocomplete callback from running due to +// pastes. Additionally, any lines that are completely pasted will be returned +// from ReadLine with the error set to ErrPasteIndicator. +func (t *Terminal) SetBracketedPasteMode(on bool) { + if on { + io.WriteString(t.c, "\x1b[?2004h") + } else { + io.WriteString(t.c, "\x1b[?2004l") + } +} + +// stRingBuffer is a ring buffer of strings. +type stRingBuffer struct { + // entries contains max elements. + entries []string + max int + // head contains the index of the element most recently added to the ring. + head int + // size contains the number of elements in the ring. + size int +} + +func (s *stRingBuffer) Add(a string) { + if s.entries == nil { + const defaultNumEntries = 100 + s.entries = make([]string, defaultNumEntries) + s.max = defaultNumEntries + } + + s.head = (s.head + 1) % s.max + s.entries[s.head] = a + if s.size < s.max { + s.size++ + } +} + +// NthPreviousEntry returns the value passed to the nth previous call to Add. +// If n is zero then the immediately prior value is returned, if one, then the +// next most recent, and so on. If such an element doesn't exist then ok is +// false. +func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) { + if n >= s.size { + return "", false + } + index := s.head - n + if index < 0 { + index += s.max + } + return s.entries[index], true +} diff --git a/vendor/golang.org/x/crypto/ssh/terminal/terminal_test.go b/vendor/golang.org/x/crypto/ssh/terminal/terminal_test.go new file mode 100644 index 0000000..1d54c4f --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/terminal_test.go @@ -0,0 +1,306 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import ( + "bytes" + "io" + "os" + "testing" +) + +type MockTerminal struct { + toSend []byte + bytesPerRead int + received []byte +} + +func (c *MockTerminal) Read(data []byte) (n int, err error) { + n = len(data) + if n == 0 { + return + } + if n > len(c.toSend) { + n = len(c.toSend) + } + if n == 0 { + return 0, io.EOF + } + if c.bytesPerRead > 0 && n > c.bytesPerRead { + n = c.bytesPerRead + } + copy(data, c.toSend[:n]) + c.toSend = c.toSend[n:] + return +} + +func (c *MockTerminal) Write(data []byte) (n int, err error) { + c.received = append(c.received, data...) + return len(data), nil +} + +func TestClose(t *testing.T) { + c := &MockTerminal{} + ss := NewTerminal(c, "> ") + line, err := ss.ReadLine() + if line != "" { + t.Errorf("Expected empty line but got: %s", line) + } + if err != io.EOF { + t.Errorf("Error should have been EOF but got: %s", err) + } +} + +var keyPressTests = []struct { + in string + line string + err error + throwAwayLines int +}{ + { + err: io.EOF, + }, + { + in: "\r", + line: "", + }, + { + in: "foo\r", + line: "foo", + }, + { + in: "a\x1b[Cb\r", // right + line: "ab", + }, + { + in: "a\x1b[Db\r", // left + line: "ba", + }, + { + in: "a\177b\r", // backspace + line: "b", + }, + { + in: "\x1b[A\r", // up + }, + { + in: "\x1b[B\r", // down + }, + { + in: "line\x1b[A\x1b[B\r", // up then down + line: "line", + }, + { + in: "line1\rline2\x1b[A\r", // recall previous line. + line: "line1", + throwAwayLines: 1, + }, + { + // recall two previous lines and append. + in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r", + line: "line1xxx", + throwAwayLines: 2, + }, + { + // Ctrl-A to move to beginning of line followed by ^K to kill + // line. + in: "a b \001\013\r", + line: "", + }, + { + // Ctrl-A to move to beginning of line, Ctrl-E to move to end, + // finally ^K to kill nothing. + in: "a b \001\005\013\r", + line: "a b ", + }, + { + in: "\027\r", + line: "", + }, + { + in: "a\027\r", + line: "", + }, + { + in: "a \027\r", + line: "", + }, + { + in: "a b\027\r", + line: "a ", + }, + { + in: "a b \027\r", + line: "a ", + }, + { + in: "one two thr\x1b[D\027\r", + line: "one two r", + }, + { + in: "\013\r", + line: "", + }, + { + in: "a\013\r", + line: "a", + }, + { + in: "ab\x1b[D\013\r", + line: "a", + }, + { + in: "Ξεσκεπάζω\r", + line: "Ξεσκεπάζω", + }, + { + in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace. + line: "", + throwAwayLines: 1, + }, + { + in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter. + line: "£", + throwAwayLines: 1, + }, + { + // Ctrl-D at the end of the line should be ignored. + in: "a\004\r", + line: "a", + }, + { + // a, b, left, Ctrl-D should erase the b. + in: "ab\x1b[D\004\r", + line: "a", + }, + { + // a, b, c, d, left, left, ^U should erase to the beginning of + // the line. + in: "abcd\x1b[D\x1b[D\025\r", + line: "cd", + }, + { + // Bracketed paste mode: control sequences should be returned + // verbatim in paste mode. + in: "abc\x1b[200~de\177f\x1b[201~\177\r", + line: "abcde\177", + }, + { + // Enter in bracketed paste mode should still work. + in: "abc\x1b[200~d\refg\x1b[201~h\r", + line: "efgh", + throwAwayLines: 1, + }, + { + // Lines consisting entirely of pasted data should be indicated as such. + in: "\x1b[200~a\r", + line: "a", + err: ErrPasteIndicator, + }, +} + +func TestKeyPresses(t *testing.T) { + for i, test := range keyPressTests { + for j := 1; j < len(test.in); j++ { + c := &MockTerminal{ + toSend: []byte(test.in), + bytesPerRead: j, + } + ss := NewTerminal(c, "> ") + for k := 0; k < test.throwAwayLines; k++ { + _, err := ss.ReadLine() + if err != nil { + t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err) + } + } + line, err := ss.ReadLine() + if line != test.line { + t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) + break + } + if err != test.err { + t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err) + break + } + } + } +} + +func TestPasswordNotSaved(t *testing.T) { + c := &MockTerminal{ + toSend: []byte("password\r\x1b[A\r"), + bytesPerRead: 1, + } + ss := NewTerminal(c, "> ") + pw, _ := ss.ReadPassword("> ") + if pw != "password" { + t.Fatalf("failed to read password, got %s", pw) + } + line, _ := ss.ReadLine() + if len(line) > 0 { + t.Fatalf("password was saved in history") + } +} + +var setSizeTests = []struct { + width, height int +}{ + {40, 13}, + {80, 24}, + {132, 43}, +} + +func TestTerminalSetSize(t *testing.T) { + for _, setSize := range setSizeTests { + c := &MockTerminal{ + toSend: []byte("password\r\x1b[A\r"), + bytesPerRead: 1, + } + ss := NewTerminal(c, "> ") + ss.SetSize(setSize.width, setSize.height) + pw, _ := ss.ReadPassword("Password: ") + if pw != "password" { + t.Fatalf("failed to read password, got %s", pw) + } + if string(c.received) != "Password: \r\n" { + t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received) + } + } +} + +func TestMakeRawState(t *testing.T) { + fd := int(os.Stdout.Fd()) + if !IsTerminal(fd) { + t.Skip("stdout is not a terminal; skipping test") + } + + st, err := GetState(fd) + if err != nil { + t.Fatalf("failed to get terminal state from GetState: %s", err) + } + defer Restore(fd, st) + raw, err := MakeRaw(fd) + if err != nil { + t.Fatalf("failed to get terminal state from MakeRaw: %s", err) + } + + if *st != *raw { + t.Errorf("states do not match; was %v, expected %v", raw, st) + } +} + +func TestOutputNewlines(t *testing.T) { + // \n should be changed to \r\n in terminal output. + buf := new(bytes.Buffer) + term := NewTerminal(buf, ">") + + term.Write([]byte("1\n2\n")) + output := string(buf.Bytes()) + const expected = "1\r\n2\r\n" + + if output != expected { + t.Errorf("incorrect output: was %q, expected %q", output, expected) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/terminal/util.go b/vendor/golang.org/x/crypto/ssh/terminal/util.go new file mode 100644 index 0000000..c869213 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/util.go @@ -0,0 +1,133 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal // import "golang.org/x/crypto/ssh/terminal" + +import ( + "io" + "syscall" + "unsafe" +) + +// State contains the state of a terminal. +type State struct { + termios syscall.Termios +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var termios syscall.Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var oldState State + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { + return nil, err + } + + newState := oldState.termios + // This attempts to replicate the behaviour documented for cfmakeraw in + // the termios(3) manpage. + newState.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON + newState.Oflag &^= syscall.OPOST + newState.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN + newState.Cflag &^= syscall.CSIZE | syscall.PARENB + newState.Cflag |= syscall.CS8 + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { + return nil, err + } + + return &oldState, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var oldState State + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { + return nil, err + } + + return &oldState, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0) + return err +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + var dimensions [4]uint16 + + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 { + return -1, -1, err + } + return int(dimensions[1]), int(dimensions[0]), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var oldState syscall.Termios + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { + return nil, err + } + + newState := oldState + newState.Lflag &^= syscall.ECHO + newState.Lflag |= syscall.ICANON | syscall.ISIG + newState.Iflag |= syscall.ICRNL + if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { + return nil, err + } + + defer func() { + syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) + }() + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(fd, buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/terminal/util_bsd.go b/vendor/golang.org/x/crypto/ssh/terminal/util_bsd.go new file mode 100644 index 0000000..9c1ffd1 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/util_bsd.go @@ -0,0 +1,12 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd netbsd openbsd + +package terminal + +import "syscall" + +const ioctlReadTermios = syscall.TIOCGETA +const ioctlWriteTermios = syscall.TIOCSETA diff --git a/vendor/golang.org/x/crypto/ssh/terminal/util_linux.go b/vendor/golang.org/x/crypto/ssh/terminal/util_linux.go new file mode 100644 index 0000000..5883b22 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/util_linux.go @@ -0,0 +1,11 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +// These constants are declared here, rather than importing +// them from the syscall package as some syscall packages, even +// on linux, for example gccgo, do not declare them. +const ioctlReadTermios = 0x5401 // syscall.TCGETS +const ioctlWriteTermios = 0x5402 // syscall.TCSETS diff --git a/vendor/golang.org/x/crypto/ssh/terminal/util_plan9.go b/vendor/golang.org/x/crypto/ssh/terminal/util_plan9.go new file mode 100644 index 0000000..799f049 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/util_plan9.go @@ -0,0 +1,58 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "fmt" + "runtime" +) + +type State struct{} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + return false +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/vendor/golang.org/x/crypto/ssh/terminal/util_solaris.go b/vendor/golang.org/x/crypto/ssh/terminal/util_solaris.go new file mode 100644 index 0000000..07eb5ed --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/util_solaris.go @@ -0,0 +1,73 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build solaris + +package terminal // import "golang.org/x/crypto/ssh/terminal" + +import ( + "golang.org/x/sys/unix" + "io" + "syscall" +) + +// State contains the state of a terminal. +type State struct { + termios syscall.Termios +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + // see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c + var termio unix.Termio + err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio) + return err == nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + // see also: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c + val, err := unix.IoctlGetTermios(fd, unix.TCGETS) + if err != nil { + return nil, err + } + oldState := *val + + newState := oldState + newState.Lflag &^= syscall.ECHO + newState.Lflag |= syscall.ICANON | syscall.ISIG + newState.Iflag |= syscall.ICRNL + err = unix.IoctlSetTermios(fd, unix.TCSETS, &newState) + if err != nil { + return nil, err + } + + defer unix.IoctlSetTermios(fd, unix.TCSETS, &oldState) + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(fd, buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/terminal/util_windows.go b/vendor/golang.org/x/crypto/ssh/terminal/util_windows.go new file mode 100644 index 0000000..ae9fa9e --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/terminal/util_windows.go @@ -0,0 +1,174 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "io" + "syscall" + "unsafe" +) + +const ( + enableLineInput = 2 + enableEchoInput = 4 + enableProcessedInput = 1 + enableWindowInput = 8 + enableMouseInput = 16 + enableInsertMode = 32 + enableQuickEditMode = 64 + enableExtendedFlags = 128 + enableAutoPosition = 256 + enableProcessedOutput = 1 + enableWrapAtEolOutput = 2 +) + +var kernel32 = syscall.NewLazyDLL("kernel32.dll") + +var ( + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + +type ( + short int16 + word uint16 + + coord struct { + x short + y short + } + smallRect struct { + left short + top short + right short + bottom short + } + consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes word + window smallRect + maximumWindowSize coord + } +) + +type State struct { + mode uint32 +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var st uint32 + r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + return r != 0 && e == 0 +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + raw := st &^ (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(raw), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + return &State{st}, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + _, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) + return err +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + var info consoleScreenBufferInfo + _, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) + if e != 0 { + return 0, 0, error(e) + } + return int(info.size.x), int(info.size.y), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var st uint32 + _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + if e != 0 { + return nil, error(e) + } + old := st + + st &^= (enableEchoInput) + st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) + _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) + if e != 0 { + return nil, error(e) + } + + defer func() { + syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) + }() + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(syscall.Handle(fd), buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + if n > 0 && buf[n-1] == '\r' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/test/agent_unix_test.go b/vendor/golang.org/x/crypto/ssh/test/agent_unix_test.go new file mode 100644 index 0000000..f481253 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/agent_unix_test.go @@ -0,0 +1,59 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "bytes" + "testing" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func TestAgentForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + keyring := agent.NewKeyring() + if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil { + t.Fatalf("Error adding key: %s", err) + } + if err := keyring.Add(agent.AddedKey{ + PrivateKey: testPrivateKeys["dsa"], + ConfirmBeforeUse: true, + LifetimeSecs: 3600, + }); err != nil { + t.Fatalf("Error adding key with constraints: %s", err) + } + pub := testPublicKeys["dsa"] + + sess, err := conn.NewSession() + if err != nil { + t.Fatalf("NewSession: %v", err) + } + if err := agent.RequestAgentForwarding(sess); err != nil { + t.Fatalf("RequestAgentForwarding: %v", err) + } + + if err := agent.ForwardToAgent(conn, keyring); err != nil { + t.Fatalf("SetupForwardKeyring: %v", err) + } + out, err := sess.CombinedOutput("ssh-add -L") + if err != nil { + t.Fatalf("running ssh-add: %v, out %s", err, out) + } + key, _, _, _, err := ssh.ParseAuthorizedKey(out) + if err != nil { + t.Fatalf("ParseAuthorizedKey(%q): %v", out, err) + } + + if !bytes.Equal(key.Marshal(), pub.Marshal()) { + t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub)) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/test/cert_test.go b/vendor/golang.org/x/crypto/ssh/test/cert_test.go new file mode 100644 index 0000000..364790f --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/cert_test.go @@ -0,0 +1,47 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "crypto/rand" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestCertLogin(t *testing.T) { + s := newServer(t) + defer s.Shutdown() + + // Use a key different from the default. + clientKey := testSigners["dsa"] + caAuthKey := testSigners["ecdsa"] + cert := &ssh.Certificate{ + Key: clientKey.PublicKey(), + ValidPrincipals: []string{username()}, + CertType: ssh.UserCert, + ValidBefore: ssh.CertTimeInfinity, + } + if err := cert.SignCert(rand.Reader, caAuthKey); err != nil { + t.Fatalf("SetSignature: %v", err) + } + + certSigner, err := ssh.NewCertSigner(cert, clientKey) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + conf := &ssh.ClientConfig{ + User: username(), + } + conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) + client, err := s.TryDial(conf) + if err != nil { + t.Fatalf("TryDial: %v", err) + } + client.Close() +} diff --git a/vendor/golang.org/x/crypto/ssh/test/doc.go b/vendor/golang.org/x/crypto/ssh/test/doc.go new file mode 100644 index 0000000..3f9b334 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/doc.go @@ -0,0 +1,7 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package contains integration tests for the +// golang.org/x/crypto/ssh package. +package test // import "golang.org/x/crypto/ssh/test" diff --git a/vendor/golang.org/x/crypto/ssh/test/forward_unix_test.go b/vendor/golang.org/x/crypto/ssh/test/forward_unix_test.go new file mode 100644 index 0000000..877a88c --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/forward_unix_test.go @@ -0,0 +1,160 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package test + +import ( + "bytes" + "io" + "io/ioutil" + "math/rand" + "net" + "testing" + "time" +) + +func TestPortForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + go func() { + sshConn, err := sshListener.Accept() + if err != nil { + t.Fatalf("listen.Accept failed: %v", err) + } + + _, err = io.Copy(sshConn, sshConn) + if err != nil && err != io.EOF { + t.Fatalf("ssh client copy: %v", err) + } + sshConn.Close() + }() + + forwardedAddr := sshListener.Addr().String() + tcpConn, err := net.Dial("tcp", forwardedAddr) + if err != nil { + t.Fatalf("TCP dial failed: %v", err) + } + + readChan := make(chan []byte) + go func() { + data, _ := ioutil.ReadAll(tcpConn) + readChan <- data + }() + + // Invent some data. + data := make([]byte, 100*1000) + for i := range data { + data[i] = byte(i % 255) + } + + var sent []byte + for len(sent) < 1000*1000 { + // Send random sized chunks + m := rand.Intn(len(data)) + n, err := tcpConn.Write(data[:m]) + if err != nil { + break + } + sent = append(sent, data[:n]...) + } + if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { + t.Errorf("tcpConn.CloseWrite: %v", err) + } + + read := <-readChan + + if len(sent) != len(read) { + t.Fatalf("got %d bytes, want %d", len(read), len(sent)) + } + if bytes.Compare(sent, read) != 0 { + t.Fatalf("read back data does not match") + } + + if err := sshListener.Close(); err != nil { + t.Fatalf("sshListener.Close: %v", err) + } + + // Check that the forward disappeared. + tcpConn, err = net.Dial("tcp", forwardedAddr) + if err == nil { + tcpConn.Close() + t.Errorf("still listening to %s after closing", forwardedAddr) + } +} + +func TestAcceptClose(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + quit := make(chan error, 1) + go func() { + for { + c, err := sshListener.Accept() + if err != nil { + quit <- err + break + } + c.Close() + } + }() + sshListener.Close() + + select { + case <-time.After(1 * time.Second): + t.Errorf("timeout: listener did not close.") + case err := <-quit: + t.Logf("quit as expected (error %v)", err) + } +} + +// Check that listeners exit if the underlying client transport dies. +func TestPortForwardConnectionClose(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + + sshListener, err := conn.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + quit := make(chan error, 1) + go func() { + for { + c, err := sshListener.Accept() + if err != nil { + quit <- err + break + } + c.Close() + } + }() + + // It would be even nicer if we closed the server side, but it + // is more involved as the fd for that side is dup()ed. + server.clientConn.Close() + + select { + case <-time.After(1 * time.Second): + t.Errorf("timeout: listener did not close.") + case err := <-quit: + t.Logf("quit as expected (error %v)", err) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/test/session_test.go b/vendor/golang.org/x/crypto/ssh/test/session_test.go new file mode 100644 index 0000000..fc7e471 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/session_test.go @@ -0,0 +1,365 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package test + +// Session functional tests. + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestRunCommandSuccess(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + err = session.Run("true") + if err != nil { + t.Fatalf("session failed: %v", err) + } +} + +func TestHostKeyCheck(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + + conf := clientConfig() + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check + + // change the keys. + hostDB.keys[ssh.KeyAlgoRSA][25]++ + hostDB.keys[ssh.KeyAlgoDSA][25]++ + hostDB.keys[ssh.KeyAlgoECDSA256][25]++ + + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + t.Fatalf("dial should have failed.") + } else if !strings.Contains(err.Error(), "host key mismatch") { + t.Fatalf("'host key mismatch' not found in %v", err) + } +} + +func TestRunCommandStdin(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + defer w.Close() + session.Stdin = r + + err = session.Run("true") + if err != nil { + t.Fatalf("session failed: %v", err) + } +} + +func TestRunCommandStdinError(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + r, w := io.Pipe() + defer r.Close() + session.Stdin = r + pipeErr := errors.New("closing write end of pipe") + w.CloseWithError(pipeErr) + + err = session.Run("true") + if err != pipeErr { + t.Fatalf("expected %v, found %v", pipeErr, err) + } +} + +func TestRunCommandFailed(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + err = session.Run(`bash -c "kill -9 $$"`) + if err == nil { + t.Fatalf("session succeeded: %v", err) + } +} + +func TestRunCommandWeClosed(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + err = session.Shell() + if err != nil { + t.Fatalf("shell failed: %v", err) + } + err = session.Close() + if err != nil { + t.Fatalf("shell failed: %v", err) + } +} + +func TestFuncLargeRead(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("unable to create new session: %s", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + err = session.Start("dd if=/dev/urandom bs=2048 count=1024") + if err != nil { + t.Fatalf("unable to execute remote command: %s", err) + } + + buf := new(bytes.Buffer) + n, err := io.Copy(buf, stdout) + if err != nil { + t.Fatalf("error reading from remote stdout: %s", err) + } + + if n != 2048*1024 { + t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n) + } +} + +func TestKeyChange(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + hostDB := hostKeyDB() + conf.HostKeyCallback = hostDB.Check + conf.RekeyThreshold = 1024 + conn := server.Dial(conf) + defer conn.Close() + + for i := 0; i < 4; i++ { + session, err := conn.NewSession() + if err != nil { + t.Fatalf("unable to create new session: %s", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + err = session.Start("dd if=/dev/urandom bs=1024 count=1") + if err != nil { + t.Fatalf("unable to execute remote command: %s", err) + } + buf := new(bytes.Buffer) + n, err := io.Copy(buf, stdout) + if err != nil { + t.Fatalf("error reading from remote stdout: %s", err) + } + + want := int64(1024) + if n != want { + t.Fatalf("Expected %d bytes but read only %d from remote command", want, n) + } + } + + if changes := hostDB.checkCount; changes < 4 { + t.Errorf("got %d key changes, want 4", changes) + } +} + +func TestInvalidTerminalMode(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + if err = session.RequestPty("vt100", 80, 40, ssh.TerminalModes{255: 1984}); err == nil { + t.Fatalf("req-pty failed: successful request with invalid mode") + } +} + +func TestValidTerminalMode(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatalf("session failed: %v", err) + } + defer session.Close() + + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("unable to acquire stdout pipe: %s", err) + } + + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("unable to acquire stdin pipe: %s", err) + } + + tm := ssh.TerminalModes{ssh.ECHO: 0} + if err = session.RequestPty("xterm", 80, 40, tm); err != nil { + t.Fatalf("req-pty failed: %s", err) + } + + err = session.Shell() + if err != nil { + t.Fatalf("session failed: %s", err) + } + + stdin.Write([]byte("stty -a && exit\n")) + + var buf bytes.Buffer + if _, err := io.Copy(&buf, stdout); err != nil { + t.Fatalf("reading failed: %s", err) + } + + if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") { + t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput) + } +} + +func TestCiphers(t *testing.T) { + var config ssh.Config + config.SetDefaults() + cipherOrder := config.Ciphers + // These ciphers will not be tested when commented out in cipher.go it will + // fallback to the next available as per line 292. + cipherOrder = append(cipherOrder, "aes128-cbc", "3des-cbc") + + for _, ciph := range cipherOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.Ciphers = []string{ciph} + // Don't fail if sshd doesn't have the cipher. + conf.Ciphers = append(conf.Ciphers, cipherOrder...) + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Fatalf("failed for cipher %q", ciph) + } + } +} + +func TestMACs(t *testing.T) { + var config ssh.Config + config.SetDefaults() + macOrder := config.MACs + + for _, mac := range macOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + conf.MACs = []string{mac} + // Don't fail if sshd doesn't have the MAC. + conf.MACs = append(conf.MACs, macOrder...) + if conn, err := server.TryDial(conf); err == nil { + conn.Close() + } else { + t.Fatalf("failed for MAC %q", mac) + } + } +} + +func TestKeyExchanges(t *testing.T) { + var config ssh.Config + config.SetDefaults() + kexOrder := config.KeyExchanges + for _, kex := range kexOrder { + server := newServer(t) + defer server.Shutdown() + conf := clientConfig() + // Don't fail if sshd doesn't have the kex. + conf.KeyExchanges = append([]string{kex}, kexOrder...) + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Errorf("failed for kex %q", kex) + } + } +} + +func TestClientAuthAlgorithms(t *testing.T) { + for _, key := range []string{ + "rsa", + "dsa", + "ecdsa", + "ed25519", + } { + server := newServer(t) + conf := clientConfig() + conf.SetDefaults() + conf.Auth = []ssh.AuthMethod{ + ssh.PublicKeys(testSigners[key]), + } + + conn, err := server.TryDial(conf) + if err == nil { + conn.Close() + } else { + t.Errorf("failed for key %q", key) + } + + server.Shutdown() + } +} diff --git a/vendor/golang.org/x/crypto/ssh/test/tcpip_test.go b/vendor/golang.org/x/crypto/ssh/test/tcpip_test.go new file mode 100644 index 0000000..a2eb935 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/tcpip_test.go @@ -0,0 +1,46 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package test + +// direct-tcpip functional tests + +import ( + "io" + "net" + "testing" +) + +func TestDial(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + sshConn := server.Dial(clientConfig()) + defer sshConn.Close() + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer l.Close() + + go func() { + for { + c, err := l.Accept() + if err != nil { + break + } + + io.WriteString(c, c.RemoteAddr().String()) + c.Close() + } + }() + + conn, err := sshConn.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() +} diff --git a/vendor/golang.org/x/crypto/ssh/test/test_unix_test.go b/vendor/golang.org/x/crypto/ssh/test/test_unix_test.go new file mode 100644 index 0000000..3bfd881 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/test_unix_test.go @@ -0,0 +1,268 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd plan9 + +package test + +// functional test harness for unix. + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "os/user" + "path/filepath" + "testing" + "text/template" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +const sshd_config = ` +Protocol 2 +HostKey {{.Dir}}/id_rsa +HostKey {{.Dir}}/id_dsa +HostKey {{.Dir}}/id_ecdsa +Pidfile {{.Dir}}/sshd.pid +#UsePrivilegeSeparation no +KeyRegenerationInterval 3600 +ServerKeyBits 768 +SyslogFacility AUTH +LogLevel DEBUG2 +LoginGraceTime 120 +PermitRootLogin no +StrictModes no +RSAAuthentication yes +PubkeyAuthentication yes +AuthorizedKeysFile {{.Dir}}/authorized_keys +TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub +IgnoreRhosts yes +RhostsRSAAuthentication no +HostbasedAuthentication no +PubkeyAcceptedKeyTypes=* +` + +var configTmpl = template.Must(template.New("").Parse(sshd_config)) + +type server struct { + t *testing.T + cleanup func() // executed during Shutdown + configfile string + cmd *exec.Cmd + output bytes.Buffer // holds stderr from sshd process + + // Client half of the network connection. + clientConn net.Conn +} + +func username() string { + var username string + if user, err := user.Current(); err == nil { + username = user.Username + } else { + // user.Current() currently requires cgo. If an error is + // returned attempt to get the username from the environment. + log.Printf("user.Current: %v; falling back on $USER", err) + username = os.Getenv("USER") + } + if username == "" { + panic("Unable to get username") + } + return username +} + +type storedHostKey struct { + // keys map from an algorithm string to binary key data. + keys map[string][]byte + + // checkCount counts the Check calls. Used for testing + // rekeying. + checkCount int +} + +func (k *storedHostKey) Add(key ssh.PublicKey) { + if k.keys == nil { + k.keys = map[string][]byte{} + } + k.keys[key.Type()] = key.Marshal() +} + +func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error { + k.checkCount++ + algo := key.Type() + + if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 { + return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo]) + } + return nil +} + +func hostKeyDB() *storedHostKey { + keyChecker := &storedHostKey{} + keyChecker.Add(testPublicKeys["ecdsa"]) + keyChecker.Add(testPublicKeys["rsa"]) + keyChecker.Add(testPublicKeys["dsa"]) + return keyChecker +} + +func clientConfig() *ssh.ClientConfig { + config := &ssh.ClientConfig{ + User: username(), + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(testSigners["user"]), + }, + HostKeyCallback: hostKeyDB().Check, + } + return config +} + +// unixConnection creates two halves of a connected net.UnixConn. It +// is used for connecting the Go SSH client with sshd without opening +// ports. +func unixConnection() (*net.UnixConn, *net.UnixConn, error) { + dir, err := ioutil.TempDir("", "unixConnection") + if err != nil { + return nil, nil, err + } + defer os.Remove(dir) + + addr := filepath.Join(dir, "ssh") + listener, err := net.Listen("unix", addr) + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("unix", addr) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1.(*net.UnixConn), c2.(*net.UnixConn), nil +} + +func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { + sshd, err := exec.LookPath("sshd") + if err != nil { + s.t.Skipf("skipping test: %v", err) + } + + c1, c2, err := unixConnection() + if err != nil { + s.t.Fatalf("unixConnection: %v", err) + } + + s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") + f, err := c2.File() + if err != nil { + s.t.Fatalf("UnixConn.File: %v", err) + } + defer f.Close() + s.cmd.Stdin = f + s.cmd.Stdout = f + s.cmd.Stderr = &s.output + if err := s.cmd.Start(); err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("s.cmd.Start: %v", err) + } + s.clientConn = c1 + conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) + if err != nil { + return nil, err + } + return ssh.NewClient(conn, chans, reqs), nil +} + +func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client { + conn, err := s.TryDial(config) + if err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("ssh.Client: %v", err) + } + return conn +} + +func (s *server) Shutdown() { + if s.cmd != nil && s.cmd.Process != nil { + // Don't check for errors; if it fails it's most + // likely "os: process already finished", and we don't + // care about that. Use os.Interrupt, so child + // processes are killed too. + s.cmd.Process.Signal(os.Interrupt) + s.cmd.Wait() + } + if s.t.Failed() { + // log any output from sshd process + s.t.Logf("sshd: %s", s.output.String()) + } + s.cleanup() +} + +func writeFile(path string, contents []byte) { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) + if err != nil { + panic(err) + } + defer f.Close() + if _, err := f.Write(contents); err != nil { + panic(err) + } +} + +// newServer returns a new mock ssh server. +func newServer(t *testing.T) *server { + if testing.Short() { + t.Skip("skipping test due to -short") + } + dir, err := ioutil.TempDir("", "sshtest") + if err != nil { + t.Fatal(err) + } + f, err := os.Create(filepath.Join(dir, "sshd_config")) + if err != nil { + t.Fatal(err) + } + err = configTmpl.Execute(f, map[string]string{ + "Dir": dir, + }) + if err != nil { + t.Fatal(err) + } + f.Close() + + for k, v := range testdata.PEMBytes { + filename := "id_" + k + writeFile(filepath.Join(dir, filename), v) + writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) + } + + var authkeys bytes.Buffer + for k, _ := range testdata.PEMBytes { + authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) + } + writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) + + return &server{ + t: t, + configfile: f.Name(), + cleanup: func() { + if err := os.RemoveAll(dir); err != nil { + t.Error(err) + } + }, + } +} diff --git a/vendor/golang.org/x/crypto/ssh/test/testdata_test.go b/vendor/golang.org/x/crypto/ssh/test/testdata_test.go new file mode 100644 index 0000000..a053f67 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/testdata_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package test + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]ssh.Signer + testPublicKeys map[string]ssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]ssh.Signer, n) + testPublicKeys = make(map[string]ssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &ssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/testdata/doc.go b/vendor/golang.org/x/crypto/ssh/testdata/doc.go new file mode 100644 index 0000000..fcae47c --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/testdata/doc.go @@ -0,0 +1,8 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package contains test data shared between the various subpackages of +// the golang.org/x/crypto/ssh package. Under no circumstance should +// this data be used for production code. +package testdata // import "golang.org/x/crypto/ssh/testdata" diff --git a/vendor/golang.org/x/crypto/ssh/testdata/keys.go b/vendor/golang.org/x/crypto/ssh/testdata/keys.go new file mode 100644 index 0000000..736dad9 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/testdata/keys.go @@ -0,0 +1,120 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testdata + +var PEMBytes = map[string][]byte{ + "dsa": []byte(`-----BEGIN DSA PRIVATE KEY----- +MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB +lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3 +EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD +nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV +2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r +juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr +FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz +DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj +nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY +Fmsr0W6fHB9nhS4/UXM8 +-----END DSA PRIVATE KEY----- +`), + "ecdsa": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49 +AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ +6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== +-----END EC PRIVATE KEY----- +`), + "rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2 +a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8 +Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQIDAQAB +AoGAJMCk5vqfSRzyXOTXLGIYCuR4Kj6pdsbNSeuuRGfYBeR1F2c/XdFAg7D/8s5R +38p/Ih52/Ty5S8BfJtwtvgVY9ecf/JlU/rl/QzhG8/8KC0NG7KsyXklbQ7gJT8UT +Ojmw5QpMk+rKv17ipDVkQQmPaj+gJXYNAHqImke5mm/K/h0CQQDciPmviQ+DOhOq +2ZBqUfH8oXHgFmp7/6pXw80DpMIxgV3CwkxxIVx6a8lVH9bT/AFySJ6vXq4zTuV9 +6QmZcZzDAkEA2j/UXJPIs1fQ8z/6sONOkU/BjtoePFIWJlRxdN35cZjXnBraX5UR +fFHkePv4YwqmXNqrBOvSu+w2WdSDci+IKwJAcsPRc/jWmsrJW1q3Ha0hSf/WG/Bu +X7MPuXaKpP/DkzGoUmb8ks7yqj6XWnYkPNLjCc8izU5vRwIiyWBRf4mxMwJBAILa +NDvRS0rjwt6lJGv7zPZoqDc65VfrK2aNyHx2PgFyzwrEOtuF57bu7pnvEIxpLTeM +z26i6XVMeYXAWZMTloMCQBbpGgEERQpeUknLBqUHhg/wXF6+lFA+vEGnkY+Dwab2 +KCXFGd+SQ5GdUcEMe9isUH6DYj/6/yCDoFrXXmpQb+M= +-----END RSA PRIVATE KEY----- +`), + "ed25519": []byte(`-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACA+3f7hS7g5UWwXOGVTrMfhmxyrjqz7Sxxbx7I1j8DvvwAAAJhAFfkOQBX5 +DgAAAAtzc2gtZWQyNTUxOQAAACA+3f7hS7g5UWwXOGVTrMfhmxyrjqz7Sxxbx7I1j8Dvvw +AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb +HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg== +-----END OPENSSH PRIVATE KEY----- +`), + "user": []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 +AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD +PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== +-----END EC PRIVATE KEY----- +`), +} + +var PEMEncryptedKeys = []struct { + Name string + EncryptionKey string + PEMBytes []byte +}{ + 0: { + Name: "rsa-encrypted", + EncryptionKey: "r54-G0pher_t3st$", + PEMBytes: []byte(`-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-128-CBC,3E1714DE130BC5E81327F36564B05462 + +MqW88sud4fnWk/Jk3fkjh7ydu51ZkHLN5qlQgA4SkAXORPPMj2XvqZOv1v2LOgUV +dUevUn8PZK7a9zbZg4QShUSzwE5k6wdB7XKPyBgI39mJ79GBd2U4W3h6KT6jIdWA +goQpluxkrzr2/X602IaxLEre97FT9mpKC6zxKCLvyFWVIP9n3OSFS47cTTXyFr+l +7PdRhe60nn6jSBgUNk/Q1lAvEQ9fufdPwDYY93F1wyJ6lOr0F1+mzRrMbH67NyKs +rG8J1Fa7cIIre7ueKIAXTIne7OAWqpU9UDgQatDtZTbvA7ciqGsSFgiwwW13N+Rr +hN8MkODKs9cjtONxSKi05s206A3NDU6STtZ3KuPDjFE1gMJODotOuqSM+cxKfyFq +wxpk/CHYCDdMAVBSwxb/vraOHamylL4uCHpJdBHypzf2HABt+lS8Su23uAmL87DR +yvyCS/lmpuNTndef6qHPRkoW2EV3xqD3ovosGf7kgwGJUk2ZpCLVteqmYehKlZDK +r/Jy+J26ooI2jIg9bjvD1PZq+Mv+2dQ1RlDrPG3PB+rEixw6vBaL9x3jatCd4ej7 +XG7lb3qO9xFpLsx89tkEcvpGR+broSpUJ6Mu5LBCVmrvqHjvnDhrZVz1brMiQtU9 +iMZbgXqDLXHd6ERWygk7OTU03u+l1gs+KGMfmS0h0ZYw6KGVLgMnsoxqd6cFSKNB +8Ohk9ZTZGCiovlXBUepyu8wKat1k8YlHSfIHoRUJRhhcd7DrmojC+bcbMIZBU22T +Pl2ftVRGtcQY23lYd0NNKfebF7ncjuLWQGy+vZW+7cgfI6wPIbfYfP6g7QAutk6W +KQx0AoX5woZ6cNxtpIrymaVjSMRRBkKQrJKmRp3pC/lul5E5P2cueMs1fj4OHTbJ +lAUv88ywr+R+mRgYQlFW/XQ653f6DT4t6+njfO9oBcPrQDASZel3LjXLpjjYG/N5 ++BWnVexuJX9ika8HJiFl55oqaKb+WknfNhk5cPY+x7SDV9ywQeMiDZpr0ffeYAEP +LlwwiWRDYpO+uwXHSFF3+JjWwjhs8m8g99iFb7U93yKgBB12dCEPPa2ZeH9wUHMJ +sreYhNuq6f4iWWSXpzN45inQqtTi8jrJhuNLTT543ErW7DtntBO2rWMhff3aiXbn +Uy3qzZM1nPbuCGuBmP9L2dJ3Z5ifDWB4JmOyWY4swTZGt9AVmUxMIKdZpRONx8vz +I9u9nbVPGZBcou50Pa0qTLbkWsSL94MNXrARBxzhHC9Zs6XNEtwN7mOuii7uMkVc +adrxgknBH1J1N+NX/eTKzUwJuPvDtA+Z5ILWNN9wpZT/7ed8zEnKHPNUexyeT5g3 +uw9z9jH7ffGxFYlx87oiVPHGOrCXYZYW5uoZE31SCBkbtNuffNRJRKIFeipmpJ3P +7bpAG+kGHMelQH6b+5K1Qgsv4tpuSyKeTKpPFH9Av5nN4P1ZBm9N80tzbNWqjSJm +S7rYdHnuNEVnUGnRmEUMmVuYZnNBEVN/fP2m2SEwXcP3Uh7TiYlcWw10ygaGmOr7 +MvMLGkYgQ4Utwnd98mtqa0jr0hK2TcOSFir3AqVvXN3XJj4cVULkrXe4Im1laWgp +-----END RSA PRIVATE KEY----- +`), + }, + + 1: { + Name: "dsa-encrypted", + EncryptionKey: "qG0pher-dsa_t3st$", + PEMBytes: []byte(`-----BEGIN DSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-128-CBC,7CE7A6E4A647DC01AF860210B15ADE3E + +hvnBpI99Hceq/55pYRdOzBLntIEis02JFNXuLEydWL+RJBFDn7tA+vXec0ERJd6J +G8JXlSOAhmC2H4uK3q2xR8/Y3yL95n6OIcjvCBiLsV+o3jj1MYJmErxP6zRtq4w3 +JjIjGHWmaYFSxPKQ6e8fs74HEqaeMV9ONUoTtB+aISmgaBL15Fcoayg245dkBvVl +h5Kqspe7yvOBmzA3zjRuxmSCqKJmasXM7mqs3vIrMxZE3XPo1/fWKcPuExgpVQoT +HkJZEoIEIIPnPMwT2uYbFJSGgPJVMDT84xz7yvjCdhLmqrsXgs5Qw7Pw0i0c0BUJ +b7fDJ2UhdiwSckWGmIhTLlJZzr8K+JpjCDlP+REYBI5meB7kosBnlvCEHdw2EJkH +0QDc/2F4xlVrHOLbPRFyu1Oi2Gvbeoo9EsM/DThpd1hKAlb0sF5Y0y0d+owv0PnE +R/4X3HWfIdOHsDUvJ8xVWZ4BZk9Zk9qol045DcFCehpr/3hslCrKSZHakLt9GI58 +vVQJ4L0aYp5nloLfzhViZtKJXRLkySMKdzYkIlNmW1oVGl7tce5UCNI8Nok4j6yn +IiHM7GBn+0nJoKTXsOGMIBe3ulKlKVxLjEuk9yivh/8= +-----END DSA PRIVATE KEY----- +`), + }, +} diff --git a/vendor/golang.org/x/crypto/ssh/testdata_test.go b/vendor/golang.org/x/crypto/ssh/testdata_test.go new file mode 100644 index 0000000..2da8c79 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/testdata_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package ssh + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]Signer + testPublicKeys map[string]PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]Signer, n) + testPublicKeys = make(map[string]PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/vendor/golang.org/x/crypto/ssh/transport.go b/vendor/golang.org/x/crypto/ssh/transport.go new file mode 100644 index 0000000..62fba62 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/transport.go @@ -0,0 +1,333 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bufio" + "errors" + "io" +) + +const ( + gcmCipherID = "aes128-gcm@openssh.com" + aes128cbcID = "aes128-cbc" + tripledescbcID = "3des-cbc" +) + +// packetConn represents a transport that implements packet based +// operations. +type packetConn interface { + // Encrypt and send a packet of data to the remote peer. + writePacket(packet []byte) error + + // Read a packet from the connection + readPacket() ([]byte, error) + + // Close closes the write-side of the connection. + Close() error +} + +// transport is the keyingTransport that implements the SSH packet +// protocol. +type transport struct { + reader connectionState + writer connectionState + + bufReader *bufio.Reader + bufWriter *bufio.Writer + rand io.Reader + + io.Closer +} + +// packetCipher represents a combination of SSH encryption/MAC +// protocol. A single instance should be used for one direction only. +type packetCipher interface { + // writePacket encrypts the packet and writes it to w. The + // contents of the packet are generally scrambled. + writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error + + // readPacket reads and decrypts a packet of data. The + // returned packet may be overwritten by future calls of + // readPacket. + readPacket(seqnum uint32, r io.Reader) ([]byte, error) +} + +// connectionState represents one side (read or write) of the +// connection. This is necessary because each direction has its own +// keys, and can even have its own algorithms +type connectionState struct { + packetCipher + seqNum uint32 + dir direction + pendingKeyChange chan packetCipher +} + +// prepareKeyChange sets up key material for a keychange. The key changes in +// both directions are triggered by reading and writing a msgNewKey packet +// respectively. +func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { + if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { + return err + } else { + t.reader.pendingKeyChange <- ciph + } + + if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { + return err + } else { + t.writer.pendingKeyChange <- ciph + } + + return nil +} + +// Read and decrypt next packet. +func (t *transport) readPacket() ([]byte, error) { + return t.reader.readPacket(t.bufReader) +} + +func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { + packet, err := s.packetCipher.readPacket(s.seqNum, r) + s.seqNum++ + if err == nil && len(packet) == 0 { + err = errors.New("ssh: zero length packet") + } + + if len(packet) > 0 { + switch packet[0] { + case msgNewKeys: + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + return nil, errors.New("ssh: got bogus newkeys message.") + } + + case msgDisconnect: + // Transform a disconnect message into an + // error. Since this is lowest level at which + // we interpret message types, doing it here + // ensures that we don't have to handle it + // elsewhere. + var msg disconnectMsg + if err := Unmarshal(packet, &msg); err != nil { + return nil, err + } + return nil, &msg + } + } + + // The packet may point to an internal buffer, so copy the + // packet out here. + fresh := make([]byte, len(packet)) + copy(fresh, packet) + + return fresh, err +} + +func (t *transport) writePacket(packet []byte) error { + return t.writer.writePacket(t.bufWriter, t.rand, packet) +} + +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { + changeKeys := len(packet) > 0 && packet[0] == msgNewKeys + + err := s.packetCipher.writePacket(s.seqNum, w, rand, packet) + if err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + s.seqNum++ + if changeKeys { + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + panic("ssh: no key material for msgNewKeys") + } + } + return err +} + +func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { + t := &transport{ + bufReader: bufio.NewReader(rwc), + bufWriter: bufio.NewWriter(rwc), + rand: rand, + reader: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + writer: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + Closer: rwc, + } + if isClient { + t.reader.dir = serverKeys + t.writer.dir = clientKeys + } else { + t.reader.dir = clientKeys + t.writer.dir = serverKeys + } + + return t +} + +type direction struct { + ivTag []byte + keyTag []byte + macKeyTag []byte +} + +var ( + serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} + clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} +) + +// generateKeys generates key material for IV, MAC and encryption. +func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) { + cipherMode := cipherModes[algs.Cipher] + macMode := macModes[algs.MAC] + + iv = make([]byte, cipherMode.ivSize) + key = make([]byte, cipherMode.keySize) + macKey = make([]byte, macMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + generateKeyMaterial(macKey, d.macKeyTag, kex) + return +} + +// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as +// described in RFC 4253, section 6.4. direction should either be serverKeys +// (to setup server->client keys) or clientKeys (for client->server keys). +func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + iv, key, macKey := generateKeys(d, algs, kex) + + if algs.Cipher == gcmCipherID { + return newGCMCipher(iv, key, macKey) + } + + if algs.Cipher == aes128cbcID { + return newAESCBCCipher(iv, key, macKey, algs) + } + + if algs.Cipher == tripledescbcID { + return newTripleDESCBCCipher(iv, key, macKey, algs) + } + + c := &streamPacketCipher{ + mac: macModes[algs.MAC].new(macKey), + } + c.macResult = make([]byte, c.mac.Size()) + + var err error + c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv) + if err != nil { + return nil, err + } + + return c, nil +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out, tag []byte, r *kexResult) { + var digestsSoFar []byte + + h := r.Hash.New() + for len(out) > 0 { + h.Reset() + h.Write(r.K) + h.Write(r.H) + + if len(digestsSoFar) == 0 { + h.Write(tag) + h.Write(r.SessionID) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum(nil) + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +const packageVersion = "SSH-2.0-Go" + +// Sends and receives a version line. The versionLine string should +// be US ASCII, start with "SSH-2.0-", and should not include a +// newline. exchangeVersions returns the other side's version line. +func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { + // Contrary to the RFC, we do not ignore lines that don't + // start with "SSH-2.0-" to make the library usable with + // nonconforming servers. + for _, c := range versionLine { + // The spec disallows non US-ASCII chars, and + // specifically forbids null chars. + if c < 32 { + return nil, errors.New("ssh: junk character in version line") + } + } + if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { + return + } + + them, err = readVersion(rw) + return them, err +} + +// maxVersionStringBytes is the maximum number of bytes that we'll +// accept as a version string. RFC 4253 section 4.2 limits this at 255 +// chars +const maxVersionStringBytes = 255 + +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) ([]byte, error) { + versionString := make([]byte, 0, 64) + var ok bool + var buf [1]byte + + for len(versionString) < maxVersionStringBytes { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return nil, err + } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + ok = true + break + } + + // non ASCII chars are disallowed, but we are lenient, + // since Go doesn't use null-terminated strings. + + // The RFC allows a comment after a space, however, + // all of it (version and comments) goes into the + // session hash. + versionString = append(versionString, buf[0]) + } + + if !ok { + return nil, errors.New("ssh: overflow reading version string") + } + + // There might be a '\r' on the end which we should remove. + if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/transport_test.go b/vendor/golang.org/x/crypto/ssh/transport_test.go new file mode 100644 index 0000000..92d83ab --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/transport_test.go @@ -0,0 +1,109 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "strings" + "testing" +) + +func TestReadVersion(t *testing.T) { + longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] + cases := map[string]string{ + "SSH-2.0-bla\r\n": "SSH-2.0-bla", + "SSH-2.0-bla\n": "SSH-2.0-bla", + longversion + "\r\n": longversion, + } + + for in, want := range cases { + result, err := readVersion(bytes.NewBufferString(in)) + if err != nil { + t.Errorf("readVersion(%q): %s", in, err) + } + got := string(result) + if got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + +func TestReadVersionError(t *testing.T) { + longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] + cases := []string{ + longversion + "too-long\r\n", + } + for _, in := range cases { + if _, err := readVersion(bytes.NewBufferString(in)); err == nil { + t.Errorf("readVersion(%q) should have failed", in) + } + } +} + +func TestExchangeVersionsBasic(t *testing.T) { + v := "SSH-2.0-bla" + buf := bytes.NewBufferString(v + "\r\n") + them, err := exchangeVersions(buf, []byte("xyz")) + if err != nil { + t.Errorf("exchangeVersions: %v", err) + } + + if want := "SSH-2.0-bla"; string(them) != want { + t.Errorf("got %q want %q for our version", them, want) + } +} + +func TestExchangeVersions(t *testing.T) { + cases := []string{ + "not\x000allowed", + "not allowed\n", + } + for _, c := range cases { + buf := bytes.NewBufferString("SSH-2.0-bla\r\n") + if _, err := exchangeVersions(buf, []byte(c)); err == nil { + t.Errorf("exchangeVersions(%q): should have failed", c) + } + } +} + +type closerBuffer struct { + bytes.Buffer +} + +func (b *closerBuffer) Close() error { + return nil +} + +func TestTransportMaxPacketWrite(t *testing.T) { + buf := &closerBuffer{} + tr := newTransport(buf, rand.Reader, true) + huge := make([]byte, maxPacket+1) + err := tr.writePacket(huge) + if err == nil { + t.Errorf("transport accepted write for a huge packet.") + } +} + +func TestTransportMaxPacketReader(t *testing.T) { + var header [5]byte + huge := make([]byte, maxPacket+128) + binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) + // padding. + header[4] = 0 + + buf := &closerBuffer{} + buf.Write(header[:]) + buf.Write(huge) + + tr := newTransport(buf, rand.Reader, true) + _, err := tr.readPacket() + if err == nil { + t.Errorf("transport succeeded reading huge packet.") + } else if !strings.Contains(err.Error(), "large") { + t.Errorf("got %q, should mention %q", err.Error(), "large") + } +} From d12c0265b07fb5f7740b1da7fea4d3b0eee2713e Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 13 Dec 2016 12:27:20 +0100 Subject: [PATCH 045/104] sleeping 500ms after 'START SLAVE' --- go/logic/inspect.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index cd47c78..19db8a3 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -21,6 +21,8 @@ import ( "github.com/outbrain/golib/sqlutils" ) +const startSlavePostWaitMilliseconds = 500 * time.Millisecond + // Inspector reads data from the read-MySQL-server (typically a replica, but can be the master) // It is used for gaining initial status and structure, and later also follow up on progress and changelog type Inspector struct { @@ -258,6 +260,8 @@ func (this *Inspector) restartReplication() error { if startError != nil { return startError } + time.Sleep(startSlavePostWaitMilliseconds) + log.Debugf("Replication restarted") return nil } From ba2a9d9e554a678c2609e6586ceb1a2f36f04c00 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 13 Dec 2016 16:09:34 +0100 Subject: [PATCH 046/104] support for --master-user and --master-password --- go/base/context.go | 12 +++++++----- go/cmd/gh-ost/main.go | 9 +++++++-- go/logic/migrator.go | 6 ++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 606cb69..ef25eb0 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -82,11 +82,13 @@ type MigrationContext struct { IsTungsten bool DiscardForeignKeys bool - config ContextConfig - configMutex *sync.Mutex - ConfigFile string - CliUser string - CliPassword string + config ContextConfig + configMutex *sync.Mutex + ConfigFile string + CliUser string + CliPassword string + CliMasterUser string + CliMasterPassword string HeartbeatIntervalMilliseconds int64 defaultNumRetries int64 diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 63bb230..bb8b8d2 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -50,6 +50,8 @@ func main() { flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)") flag.StringVar(&migrationContext.CliUser, "user", "", "MySQL user") flag.StringVar(&migrationContext.CliPassword, "password", "", "MySQL password") + flag.StringVar(&migrationContext.CliMasterUser, "master-user", "", "MySQL user on master, if different from that on replica. Requires --assume-master-host") + flag.StringVar(&migrationContext.CliMasterPassword, "master-password", "", "MySQL password on master, if different from that on replica. Requires --assume-master-host") flag.StringVar(&migrationContext.ConfigFile, "conf", "", "Config file") askPass := flag.Bool("ask-pass", false, "prompt for MySQL password") @@ -170,8 +172,11 @@ func main() { } log.Warning("--test-on-replica-skip-replica-stop enabled. We will not stop replication before cut-over. Ensure you have a plugin that does this.") } - if migrationContext.AssumeMasterHostname != "" && !migrationContext.AllowedMasterMaster && !migrationContext.IsTungsten { - log.Fatalf("--assume-master-host requires either --allow-master-master or --tungsten") + if migrationContext.CliMasterUser != "" && migrationContext.AssumeMasterHostname == "" { + log.Fatalf("--master-user requires --assume-master-host") + } + if migrationContext.CliMasterPassword != "" && migrationContext.AssumeMasterHostname == "" { + log.Fatalf("--master-password requires --assume-master-host") } switch *cutOver { diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 8f61425..423ed27 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -653,6 +653,12 @@ func (this *Migrator) initiateInspector() (err error) { return err } this.migrationContext.ApplierConnectionConfig = this.migrationContext.InspectorConnectionConfig.DuplicateCredentials(*key) + if this.migrationContext.CliMasterUser != "" { + this.migrationContext.ApplierConnectionConfig.User = this.migrationContext.CliMasterUser + } + if this.migrationContext.CliMasterPassword != "" { + this.migrationContext.ApplierConnectionConfig.Password = this.migrationContext.CliMasterPassword + } log.Infof("Master forced to be %+v", *this.migrationContext.ApplierConnectionConfig.ImpliedKey) } // validate configs From 5af336e29b80927d86744e61beb7130809b18ed9 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 14 Dec 2016 10:38:58 +0100 Subject: [PATCH 047/104] updated release version following some changes I wish to release --- RELEASE_VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASE_VERSION b/RELEASE_VERSION index 475bda9..15245f3 100644 --- a/RELEASE_VERSION +++ b/RELEASE_VERSION @@ -1 +1 @@ -1.0.30 +1.0.32 From fc831b0548b33f62d34b6b6aadff0614507117c8 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 26 Dec 2016 21:31:35 +0200 Subject: [PATCH 048/104] Reading replication lag via _changelog_ table, also on control replicas --- go/cmd/gh-ost/main.go | 2 +- go/logic/inspect.go | 12 +++-- go/logic/throttler.go | 101 ++++++++++++++++++++++++++++++++---------- go/mysql/utils.go | 36 +-------------- 4 files changed, 84 insertions(+), 67 deletions(-) diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index bb8b8d2..58e231e 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -91,7 +91,7 @@ func main() { replicationLagQuery := flag.String("replication-lag-query", "", "Query that detects replication lag in seconds. Result can be a floating point (by default gh-ost issues SHOW SLAVE STATUS and reads Seconds_behind_master). If you're using pt-heartbeat, query would be something like: SELECT ROUND(UNIX_TIMESTAMP() - MAX(UNIX_TIMESTAMP(ts))) AS delay FROM my_schema.heartbeat") throttleControlReplicas := flag.String("throttle-control-replicas", "", "List of replicas on which to check for lag; comma delimited. Example: myhost1.com:3306,myhost2.com,myhost3.com:3307") throttleQuery := flag.String("throttle-query", "", "when given, issued (every second) to check if operation should throttle. Expecting to return zero for no-throttle, >0 for throttle. Query is issued on the migrated server. Make sure this query is lightweight") - heartbeatIntervalMillis := flag.Int64("heartbeat-interval-millis", 500, "how frequently would gh-ost inject a heartbeat value") + heartbeatIntervalMillis := flag.Int64("heartbeat-interval-millis", 100, "how frequently would gh-ost inject a heartbeat value") flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered") flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-ost.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-ost operations") flag.StringVar(&migrationContext.PostponeCutOverFlagFile, "postpone-cut-over-flag-file", "", "while this file exists, migration will postpone the final stage of swapping tables, and will keep on syncing the ghost table. Cut-over/swapping would be ready to perform the moment the file is deleted.") diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 19db8a3..6c81e99 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -680,18 +680,18 @@ func (this *Inspector) showCreateTable(tableName string) (createTableStatement s } // readChangelogState reads changelog hints -func (this *Inspector) readChangelogState() (map[string]string, error) { +func (this *Inspector) readChangelogState(hint string) (string, error) { query := fmt.Sprintf(` - select hint, value from %s.%s where id <= 255 + select hint, value from %s.%s where hint = ? and id <= 255 `, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) - result := make(map[string]string) + result := "" err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { - result[m.GetString("hint")] = m.GetString("value") + result = m.GetString("value") return nil - }) + }, hint) return result, err } @@ -702,10 +702,8 @@ func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.Connect } func (this *Inspector) getReplicationLag() (replicationLag time.Duration, err error) { - replicationLagQuery := this.migrationContext.GetReplicationLagQuery() replicationLag, err = mysql.GetReplicationLag( this.migrationContext.InspectorConnectionConfig, - replicationLagQuery, ) return replicationLag, err } diff --git a/go/logic/throttler.go b/go/logic/throttler.go index 482087c..1635b5e 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -12,7 +12,9 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/mysql" + "github.com/github/gh-ost/go/sql" "github.com/outbrain/golib/log" + "github.com/outbrain/golib/sqlutils" ) // Throttler collects metrics related to throttling and makes informed decisison @@ -62,15 +64,24 @@ func (this *Throttler) shouldThrottle() (result bool, reason string, reasonHint return false, "", base.NoThrottleReasonHint } -// parseChangelogHeartbeat is called when a heartbeat event is intercepted -func (this *Throttler) parseChangelogHeartbeat(heartbeatValue string) (err error) { +// parseChangelogHeartbeat parses a string timestamp and deduces replication lag +func parseChangelogHeartbeat(heartbeatValue string) (lag time.Duration, err error) { heartbeatTime, err := time.Parse(time.RFC3339Nano, heartbeatValue) if err != nil { - return log.Errore(err) + return lag, err + } + lag = time.Since(heartbeatTime) + return lag, nil +} + +// parseChangelogHeartbeat parses a string timestamp and deduces replication lag +func (this *Throttler) parseChangelogHeartbeat(heartbeatValue string) (err error) { + if lag, err := parseChangelogHeartbeat(heartbeatValue); err != nil { + return log.Errore(err) + } else { + atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) + return nil } - lag := time.Since(heartbeatTime) - atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) - return nil } // collectHeartbeat reads the latest changelog heartbeat value @@ -81,11 +92,9 @@ func (this *Throttler) collectHeartbeat() { if atomic.LoadInt64(&this.migrationContext.CleanupImminentFlag) > 0 { return nil } - changelogState, err := this.inspector.readChangelogState() - if err != nil { + if heartbeatValue, err := this.inspector.readChangelogState("heartbeat"); err != nil { return log.Errore(err) - } - if heartbeatValue, ok := changelogState["heartbeat"]; ok { + } else { this.parseChangelogHeartbeat(heartbeatValue) } return nil @@ -95,23 +104,68 @@ func (this *Throttler) collectHeartbeat() { // collectControlReplicasLag polls all the control replicas to get maximum lag value func (this *Throttler) collectControlReplicasLag() { - readControlReplicasLag := func(replicationLagQuery string) error { - if (this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica) && (atomic.LoadInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag) > 0) { - return nil + + replicationLagQuery := fmt.Sprintf(` + select value from %s.%s where hint = 'heartbeat' and id <= 255 + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + + readReplicaLag := func(connectionConfig *mysql.ConnectionConfig) (lag time.Duration, err error) { + dbUri := connectionConfig.GetDBUri("information_schema") + var heartbeatValue string + if db, _, err := sqlutils.GetDB(dbUri); err != nil { + return lag, err + } else if err = db.QueryRow(replicationLagQuery).Scan(&heartbeatValue); err != nil { + return lag, err + } + lag, err = parseChangelogHeartbeat(heartbeatValue) + return lag, err + } + + readControlReplicasLag := func() (result *mysql.ReplicationLagResult) { + instanceKeyMap := this.migrationContext.GetThrottleControlReplicaKeys() + if instanceKeyMap.Len() == 0 { + return result + } + lagResults := make(chan *mysql.ReplicationLagResult, instanceKeyMap.Len()) + for replicaKey := range *instanceKeyMap { + connectionConfig := this.migrationContext.InspectorConnectionConfig.Duplicate() + connectionConfig.Key = replicaKey + + lagResult := &mysql.ReplicationLagResult{Key: connectionConfig.Key} + go func() { + lagResult.Lag, lagResult.Err = readReplicaLag(connectionConfig) + lagResults <- lagResult + }() + } + for range *instanceKeyMap { + lagResult := <-lagResults + if result == nil { + result = lagResult + } else if lagResult.Err != nil { + result = lagResult + } else if lagResult.Lag.Nanoseconds() > result.Lag.Nanoseconds() { + result = lagResult + } + } + return result + } + + checkControlReplicasLag := func() { + if (this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica) && (atomic.LoadInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag) > 0) { + // No need to read lag + return + } + if result := readControlReplicasLag(); result != nil { + this.migrationContext.SetControlReplicasLagResult(result) } - lagResult := mysql.GetMaxReplicationLag( - this.migrationContext.InspectorConnectionConfig, - this.migrationContext.GetThrottleControlReplicaKeys(), - replicationLagQuery, - ) - this.migrationContext.SetControlReplicasLagResult(lagResult) - return nil } aggressiveTicker := time.Tick(100 * time.Millisecond) relaxedFactor := 10 counter := 0 shouldReadLagAggressively := false - replicationLagQuery := "" for range aggressiveTicker { if counter%relaxedFactor == 0 { @@ -119,12 +173,11 @@ func (this *Throttler) collectControlReplicasLag() { // do not typically change at all throughout the migration, but nonetheless we check them. counter = 0 maxLagMillisecondsThrottleThreshold := atomic.LoadInt64(&this.migrationContext.MaxLagMillisecondsThrottleThreshold) - replicationLagQuery = this.migrationContext.GetReplicationLagQuery() - shouldReadLagAggressively = (replicationLagQuery != "" && maxLagMillisecondsThrottleThreshold < 1000) + shouldReadLagAggressively = (maxLagMillisecondsThrottleThreshold < 1000) } if counter == 0 || shouldReadLagAggressively { // We check replication lag every so often, or if we wish to be aggressive - readControlReplicasLag(replicationLagQuery) + checkControlReplicasLag() } counter++ } diff --git a/go/mysql/utils.go b/go/mysql/utils.go index 6b70c7c..bfd3c24 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -24,19 +24,13 @@ type ReplicationLagResult struct { // GetReplicationLag returns replication lag for a given connection config; either by explicit query // or via SHOW SLAVE STATUS -func GetReplicationLag(connectionConfig *ConnectionConfig, replicationLagQuery string) (replicationLag time.Duration, err error) { +func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time.Duration, err error) { dbUri := connectionConfig.GetDBUri("information_schema") var db *gosql.DB if db, _, err = sqlutils.GetDB(dbUri); err != nil { return replicationLag, err } - if replicationLagQuery != "" { - var floatLag float64 - err = db.QueryRow(replicationLagQuery).Scan(&floatLag) - return time.Duration(int64(floatLag*1000)) * time.Millisecond, err - } - // No explicit replication lag query. err = sqlutils.QueryRowsMap(db, `show slave status`, func(m sqlutils.RowMap) error { secondsBehindMaster := m.GetNullInt64("Seconds_Behind_Master") if !secondsBehindMaster.Valid { @@ -48,34 +42,6 @@ func GetReplicationLag(connectionConfig *ConnectionConfig, replicationLagQuery s return replicationLag, err } -// GetMaxReplicationLag concurrently checks for replication lag on given list of instance keys, -// each via GetReplicationLag -func GetMaxReplicationLag(baseConnectionConfig *ConnectionConfig, instanceKeyMap *InstanceKeyMap, replicationLagQuery string) (result *ReplicationLagResult) { - result = &ReplicationLagResult{Lag: 0} - if instanceKeyMap.Len() == 0 { - return result - } - lagResults := make(chan *ReplicationLagResult, instanceKeyMap.Len()) - for key := range *instanceKeyMap { - connectionConfig := baseConnectionConfig.Duplicate() - connectionConfig.Key = key - result := &ReplicationLagResult{Key: connectionConfig.Key} - go func() { - result.Lag, result.Err = GetReplicationLag(connectionConfig, replicationLagQuery) - lagResults <- result - }() - } - for range *instanceKeyMap { - lagResult := <-lagResults - if lagResult.Err != nil { - result = lagResult - } else if lagResult.Lag.Nanoseconds() > result.Lag.Nanoseconds() { - result = lagResult - } - } - return result -} - func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey *InstanceKey, err error) { currentUri := connectionConfig.GetDBUri("information_schema") db, _, err := sqlutils.GetDB(currentUri) From c66356bd05d6263c6706acda148658e2954d8fc7 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 26 Dec 2016 21:38:37 +0200 Subject: [PATCH 049/104] --replication-lag-query is deprecated. --- go/base/context.go | 18 ------------------ go/cmd/gh-ost/main.go | 6 ++++-- go/logic/migrator.go | 5 ----- go/logic/server.go | 3 +-- 4 files changed, 5 insertions(+), 27 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index ef25eb0..36e279d 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -95,7 +95,6 @@ type MigrationContext struct { ChunkSize int64 niceRatio float64 MaxLagMillisecondsThrottleThreshold int64 - replicationLagQuery string throttleControlReplicaKeys *mysql.InstanceKeyMap ThrottleFlagFile string ThrottleAdditionalFlagFile string @@ -455,23 +454,6 @@ func (this *MigrationContext) IsThrottled() (bool, string, ThrottleReasonHint) { return this.isThrottled, this.throttleReason, this.throttleReasonHint } -func (this *MigrationContext) GetReplicationLagQuery() string { - var query string - - this.throttleMutex.Lock() - defer this.throttleMutex.Unlock() - - query = this.replicationLagQuery - return query -} - -func (this *MigrationContext) SetReplicationLagQuery(newQuery string) { - this.throttleMutex.Lock() - defer this.throttleMutex.Unlock() - - this.replicationLagQuery = newQuery -} - func (this *MigrationContext) GetThrottleQuery() string { var query string diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 58e231e..bccef88 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -88,7 +88,7 @@ func main() { niceRatio := flag.Float64("nice-ratio", 0, "force being 'nice', imply sleep time per chunk time; range: [0.0..100.0]. Example values: 0 is aggressive. 1: for every 1ms spent copying rows, sleep additional 1ms (effectively doubling runtime); 0.7: for every 10ms spend in a rowcopy chunk, spend 7ms sleeping immediately after") maxLagMillis := flag.Int64("max-lag-millis", 1500, "replication lag at which to throttle operation") - replicationLagQuery := flag.String("replication-lag-query", "", "Query that detects replication lag in seconds. Result can be a floating point (by default gh-ost issues SHOW SLAVE STATUS and reads Seconds_behind_master). If you're using pt-heartbeat, query would be something like: SELECT ROUND(UNIX_TIMESTAMP() - MAX(UNIX_TIMESTAMP(ts))) AS delay FROM my_schema.heartbeat") + replicationLagQuery := flag.String("replication-lag-query", "", "Deprecated. gh-ost uses an internal, subsecond resolution query") throttleControlReplicas := flag.String("throttle-control-replicas", "", "List of replicas on which to check for lag; comma delimited. Example: myhost1.com:3306,myhost2.com,myhost3.com:3307") throttleQuery := flag.String("throttle-query", "", "when given, issued (every second) to check if operation should throttle. Expecting to return zero for no-throttle, >0 for throttle. Query is issued on the migrated server. Make sure this query is lightweight") heartbeatIntervalMillis := flag.Int64("heartbeat-interval-millis", 100, "how frequently would gh-ost inject a heartbeat value") @@ -178,6 +178,9 @@ func main() { if migrationContext.CliMasterPassword != "" && migrationContext.AssumeMasterHostname == "" { log.Fatalf("--master-password requires --assume-master-host") } + if *replicationLagQuery != "" { + log.Warningf("--replication-lag-query is deprecated") + } switch *cutOver { case "atomic", "default", "": @@ -214,7 +217,6 @@ func main() { migrationContext.SetNiceRatio(*niceRatio) migrationContext.SetChunkSize(*chunkSize) migrationContext.SetMaxLagMillisecondsThrottleThreshold(*maxLagMillis) - migrationContext.SetReplicationLagQuery(*replicationLagQuery) migrationContext.SetThrottleQuery(*throttleQuery) migrationContext.SetDefaultNumRetries(*defaultRetries) migrationContext.ApplyCredentials() diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 423ed27..2477989 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -723,11 +723,6 @@ func (this *Migrator) printMigrationStatusHint(writers ...io.Writer) { criticalLoad.String(), this.migrationContext.GetNiceRatio(), )) - if replicationLagQuery := this.migrationContext.GetReplicationLagQuery(); replicationLagQuery != "" { - fmt.Fprintln(w, fmt.Sprintf("# replication-lag-query: %+v", - replicationLagQuery, - )) - } if this.migrationContext.ThrottleFlagFile != "" { setIndicator := "" if base.FileExists(this.migrationContext.ThrottleFlagFile) { diff --git a/go/logic/server.go b/go/logic/server.go index b8c0f2c..f616f60 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -178,8 +178,7 @@ help # This message } case "replication-lag-query": { - this.migrationContext.SetReplicationLagQuery(arg) - return ForcePrintStatusAndHintRule, nil + return NoPrintStatusRule, fmt.Errorf("replication-lag-query is deprecated. gh-ost uses an internal, subsecond resolution query") } case "nice-ratio": { From 21cf1ffdc5f67d36ad71eef4ffdab915301cfd1a Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 27 Dec 2016 08:48:34 +0200 Subject: [PATCH 050/104] updated documentation to match --replication-lag-query deprecation --- doc/command-line-flags.md | 5 +---- doc/subsecond-lag.md | 22 +++++----------------- doc/throttle.md | 10 +--------- 3 files changed, 7 insertions(+), 30 deletions(-) diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index 1707685..d7bc97b 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -100,10 +100,7 @@ On a replication topology, this is perhaps the most important migration throttli When using [Connect to replica, migrate on master](cheatsheet.md), this lag is primarily tested on the very replica `gh-ost` operates on. Lag is measured by checking the heartbeat events injected by `gh-ost` itself on the utility changelog table. That is, to measure this replica's lag, `gh-ost` doesn't need to issue `show slave status` nor have any external heartbeat mechanism. -When `--throttle-control-replicas` is provided, throttling also considers lag on specified hosts. Measuring lag on these hosts works as follows: - -- If `--replication-lag-query` is provided, use the query, trust its result to indicate lag seconds (fraction, i.e. float, allowed) -- Otherwise, issue `show slave status` and read `Seconds_behind_master` (`1sec` granularity) +When `--throttle-control-replicas` is provided, throttling also considers lag on specified hosts. Lag measurements on listed hosts is done by querying `gh-ost`'s _changelog_ table, where `gh-ost` injects a heartbeat. See also: [Sub-second replication lag throttling](subsecond-lag.md) diff --git a/doc/subsecond-lag.md b/doc/subsecond-lag.md index 00f67cb..85ff941 100644 --- a/doc/subsecond-lag.md +++ b/doc/subsecond-lag.md @@ -2,7 +2,7 @@ `gh-ost` is able to utilize sub-second replication lag measurements. -At GitHub, small replication lag is crucial, and we like to keep it below `1s` at all times. If you have similar concern, we strongly urge you to proceed to implement sub-second lag throttling. +At GitHub, small replication lag is crucial, and we like to keep it below `1s` at all times. `gh-ost` will do sub-second throttling when `--max-lag-millis` is smaller than `1000`, i.e. smaller than `1sec`. Replication lag is measured on: @@ -10,24 +10,12 @@ Replication lag is measured on: - The "inspected" server (the server `gh-ost` connects to; replica is desired but not mandatory) - The `throttle-control-replicas` list -For the inspected server, `gh-ost` uses an internal heartbeat mechanism. It injects heartbeat events onto the utility changelog table, then reads those events in the binary log, and compares times. This measurement is by default and by definition sub-second enabled. +`gh-ost` uses an internal heartbeat mechanism. It injects heartbeat events onto the utility changelog table, then reads those events in the binary log, and compares times. This measurement is on by default and by definition supports sub-second resolution. You can explicitly define how frequently will `gh-ost` inject heartbeat events, via `heartbeat-interval-millis`. You should set `heartbeat-interval-millis <= max-lag-millis`. It still works if not, but loses granularity and effect. -On the `throttle-control-replicas`, `gh-ost` only issues SQL queries, and does not attempt to read the binary log stream. Perhaps those other replicas don't have binary logs in the first place. +`gh-ost` will use its own internal heartbeat to measure lag both on the inspected replica as well as on the `--throttle-control-replicas` list, if provided. -The standard way of getting replication lag on a replica is to issue `SHOW SLAVE STATUS`, then reading `Seconds_behind_master` value. But that value has a `1sec` granularity. +In earlier versions, the `--throttle-control-replicas` list was subjected to `1` second resolution or to 3rd party heartbeat injections such as `pt-heartbeat`. This is no longer the case. The argument `--replication-lag-query` has been deprecated and no longer needed. -To be able to throttle on your production replicas fleet when replication lag exceeds a sub-second threshold, you must provide with a `replication-lag-query` that returns a sub-second resolution lag. - -As a common example, many use [pt-heartbeat](https://www.percona.com/doc/percona-toolkit/2.2/pt-heartbeat.html) to inject heartbeat events on the master. You would issue something like: - - /usr/bin/pt-heartbeat -- -D your_schema --create-table --update --replace --interval=0.1 --daemonize --pid ... - -Note `--interval=0.1` to indicate `10` heartbeats per second. - -You would then provide - - gh-ost ... --replication-lag-query="select unix_timestamp(now(6)) - unix_timestamp(ts) as ghost_lag_check from your_schema.heartbeat order by ts desc limit 1" - -Our production migrations use sub-second lag throttling and are able to keep our entire fleet of replicas well below `1sec` lag. +Our production migrations use sub-second lag throttling and are able to keep our entire fleet of replicas well below `1sec` lag. We use `--heartbeat-interval-millis=100` on our production migrations with a `--mas-lag-millis` value of between `300` and `500`. diff --git a/doc/throttle.md b/doc/throttle.md index 94e12ab..bc4a315 100644 --- a/doc/throttle.md +++ b/doc/throttle.md @@ -28,15 +28,7 @@ Otherwise you may specify your own list of replica servers you wish it to observ - `--max-lag-millis`: maximum allowed lag; any controlled replica lagging more than this value will cause throttling to kick in. When all control replicas have smaller lag than indicated, operation resumes. -- `--replication-lag-query`: `gh-ost` will, by default, issue a `show slave status` query to find replication lag. However, this is a notoriously flaky value. If you're using your own `heartbeat` mechanism, e.g. via [`pt-heartbeat`](https://www.percona.com/doc/percona-toolkit/2.2/pt-heartbeat.html), you may provide your own custom query to return a single decimal (floating point) value indicating replication lag. - - Example: `--replication-lag-query="SELECT UNIX_TIMESTAMP() - MAX(UNIX_TIMESTAMP(ts)) AS lag FROM mydb.heartbeat"` - - We encourage you to use [sub-second replication lag throttling](subsecond-lag.md). Your query may then look like: - - `--replication-lag-query="SELECT UNIX_TIMESTAMP(6) - MAX(UNIX_TIMESTAMP(ts)) AS lag FROM mydb.heartbeat"` - -Note that you may dynamically change both `replication-lag-query` and the `throttle-control-replicas` list via [interactive commands](interactive-commands.md) +Note that you may dynamically change both `--max-lag-millis` and the `throttle-control-replicas` list via [interactive commands](interactive-commands.md) #### Status thresholds From 9f11c6e89e66424ec7b7de1c6b38fae75f60661c Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 27 Dec 2016 09:14:11 +0200 Subject: [PATCH 051/104] docs: no longer using binary logs for heartbeat --- doc/subsecond-lag.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/subsecond-lag.md b/doc/subsecond-lag.md index 85ff941..a966aa0 100644 --- a/doc/subsecond-lag.md +++ b/doc/subsecond-lag.md @@ -10,7 +10,7 @@ Replication lag is measured on: - The "inspected" server (the server `gh-ost` connects to; replica is desired but not mandatory) - The `throttle-control-replicas` list -`gh-ost` uses an internal heartbeat mechanism. It injects heartbeat events onto the utility changelog table, then reads those events in the binary log, and compares times. This measurement is on by default and by definition supports sub-second resolution. +`gh-ost` uses an internal heartbeat mechanism. It injects heartbeat events onto the utility changelog table, then reads those entries on replicas, and compares times. This measurement is on by default and by definition supports sub-second resolution. You can explicitly define how frequently will `gh-ost` inject heartbeat events, via `heartbeat-interval-millis`. You should set `heartbeat-interval-millis <= max-lag-millis`. It still works if not, but loses granularity and effect. From 9be3c099eeaa18fa56790a00353e4be48d1815d5 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 27 Dec 2016 09:15:32 +0200 Subject: [PATCH 052/104] removed redundant paragraph --- doc/subsecond-lag.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/doc/subsecond-lag.md b/doc/subsecond-lag.md index a966aa0..c786ca2 100644 --- a/doc/subsecond-lag.md +++ b/doc/subsecond-lag.md @@ -10,12 +10,10 @@ Replication lag is measured on: - The "inspected" server (the server `gh-ost` connects to; replica is desired but not mandatory) - The `throttle-control-replicas` list -`gh-ost` uses an internal heartbeat mechanism. It injects heartbeat events onto the utility changelog table, then reads those entries on replicas, and compares times. This measurement is on by default and by definition supports sub-second resolution. +In both cases, `gh-ost` uses an internal heartbeat mechanism. It injects heartbeat events onto the utility changelog table, then reads those entries on replicas, and compares times. This measurement is on by default and by definition supports sub-second resolution. You can explicitly define how frequently will `gh-ost` inject heartbeat events, via `heartbeat-interval-millis`. You should set `heartbeat-interval-millis <= max-lag-millis`. It still works if not, but loses granularity and effect. -`gh-ost` will use its own internal heartbeat to measure lag both on the inspected replica as well as on the `--throttle-control-replicas` list, if provided. - In earlier versions, the `--throttle-control-replicas` list was subjected to `1` second resolution or to 3rd party heartbeat injections such as `pt-heartbeat`. This is no longer the case. The argument `--replication-lag-query` has been deprecated and no longer needed. Our production migrations use sub-second lag throttling and are able to keep our entire fleet of replicas well below `1sec` lag. We use `--heartbeat-interval-millis=100` on our production migrations with a `--mas-lag-millis` value of between `300` and `500`. From 3bf64d8280b7cd639c95f748ccff02e90a7f4345 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 27 Dec 2016 09:16:27 +0200 Subject: [PATCH 053/104] typo --- doc/subsecond-lag.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/subsecond-lag.md b/doc/subsecond-lag.md index c786ca2..564d337 100644 --- a/doc/subsecond-lag.md +++ b/doc/subsecond-lag.md @@ -16,4 +16,4 @@ You can explicitly define how frequently will `gh-ost` inject heartbeat events, In earlier versions, the `--throttle-control-replicas` list was subjected to `1` second resolution or to 3rd party heartbeat injections such as `pt-heartbeat`. This is no longer the case. The argument `--replication-lag-query` has been deprecated and no longer needed. -Our production migrations use sub-second lag throttling and are able to keep our entire fleet of replicas well below `1sec` lag. We use `--heartbeat-interval-millis=100` on our production migrations with a `--mas-lag-millis` value of between `300` and `500`. +Our production migrations use sub-second lag throttling and are able to keep our entire fleet of replicas well below `1sec` lag. We use `--heartbeat-interval-millis=100` on our production migrations with a `--max-lag-millis` value of between `300` and `500`. From 2bfc672b030a0ede1edc192dc53ea0d67b82c886 Mon Sep 17 00:00:00 2001 From: Cyprian Nosek Date: Wed, 28 Dec 2016 10:25:42 +0100 Subject: [PATCH 054/104] Add possibility to set serverId as gh-ost argument Signed-off-by: Cyprian Nosek --- go/base/context.go | 1 + go/binlog/gomysql_reader.go | 9 +++++---- go/cmd/gh-ost/main.go | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index ef25eb0..0f74f97 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -123,6 +123,7 @@ type MigrationContext struct { InitiallyDropOldTable bool InitiallyDropGhostTable bool CutOverType CutOver + ReplicaServerId uint Hostname string AssumeMasterHostname string diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index 6d338ca..7b57a99 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -9,6 +9,7 @@ import ( "fmt" "sync" + "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" @@ -17,10 +18,6 @@ import ( "github.com/siddontang/go-mysql/replication" ) -const ( - serverId = 99999 -) - type GoMySQLReader struct { connectionConfig *mysql.ConnectionConfig binlogSyncer *replication.BinlogSyncer @@ -28,6 +25,7 @@ type GoMySQLReader struct { currentCoordinates mysql.BinlogCoordinates currentCoordinatesMutex *sync.Mutex LastAppliedRowsEventHint mysql.BinlogCoordinates + MigrationContext *base.MigrationContext } func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *GoMySQLReader, err error) { @@ -37,7 +35,10 @@ func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *G currentCoordinatesMutex: &sync.Mutex{}, binlogSyncer: nil, binlogStreamer: nil, + MigrationContext: base.GetMigrationContext(), } + + serverId := uint32(binlogReader.MigrationContext.ReplicaServerId) binlogReader.binlogSyncer = replication.NewBinlogSyncer(serverId, "mysql") return binlogReader, err diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index bb8b8d2..6491f57 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -104,6 +104,8 @@ func main() { flag.StringVar(&migrationContext.HooksPath, "hooks-path", "", "directory where hook files are found (default: empty, ie. hooks disabled). Hook files found on this path, and conforming to hook naming conventions will be executed") flag.StringVar(&migrationContext.HooksHintMessage, "hooks-hint", "", "arbitrary message to be injected to hooks via GH_OST_HOOKS_HINT, for your convenience") + flag.UintVar(&migrationContext.ReplicaServerId, "replica-server-id", 99999, "server id used by gh-ost process. Default: 99999") + maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'. When status exceeds threshold, app throttles writes") criticalLoad := flag.String("critical-load", "", "Comma delimited status-name=threshold, same format as --max-load. When status exceeds threshold, app panics and quits") flag.Int64Var(&migrationContext.CriticalLoadIntervalMilliseconds, "critical-load-interval-millis", 0, "When 0, migration immediately bails out upon meeting critical-load. When non-zero, a second check is done after given interval, and migration only bails out if 2nd check still meets critical load") From 9dd8d858a17c600152c8b1a4bd97fc6e0c2b61de Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 29 Dec 2016 17:39:44 +0200 Subject: [PATCH 055/104] cli: --check-flag --- go/cmd/gh-ost/main.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 6491f57..fdfa919 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -115,8 +115,13 @@ func main() { stack := flag.Bool("stack", false, "add stack trace upon error") help := flag.Bool("help", false, "Display usage") version := flag.Bool("version", false, "Print version & exit") + checkFlag := flag.Bool("check-flag", false, "Check if another flag exists/supported. This allows for cross-version scripting. Exits with 0 when all additional provided flags exist, nonzero otherwise. You must provide (dummy) values for flags that require a value. Example: gh-ost --check-flag --cut-over-lock-timeout-seconds --nice-ratio 0") + flag.Parse() + if *checkFlag { + return + } if *help { fmt.Fprintf(os.Stderr, "Usage of gh-ost:\n") flag.PrintDefaults() From 3e28f462d88c9eb9cc84a26b6228fd278170fdb3 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 3 Jan 2017 13:44:52 +0200 Subject: [PATCH 056/104] Initial support for batching multiple DMLs when writing to ghost table --- go/logic/applier.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/go/logic/applier.go b/go/logic/applier.go index afff578..8fb940a 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -950,3 +950,48 @@ func (this *Applier) ApplyDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) error { } return nil } + +// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table +func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error { + + var totalDelta int64 + + err := func() error { + tx, err := this.db.Begin() + if err != nil { + return err + } + sessionQuery := `SET + SESSION time_zone = '+00:00', + sql_mode = CONCAT(@@session.sql_mode, ',STRICT_ALL_TABLES') + ` + if _, err := tx.Exec(sessionQuery); err != nil { + return err + } + for _, dmlEvent := range dmlEvents { + query, args, rowDelta, err := this.buildDMLEventQuery(dmlEvent) + if err != nil { + return err + } + if _, err := tx.Exec(query, args...); err != nil { + err = fmt.Errorf("%s; query=%s; args=%+v", err.Error(), query, args) + return err + } + totalDelta += rowDelta + } + if err := tx.Commit(); err != nil { + return err + } + return nil + }() + + if err != nil { + return log.Errore(err) + } + // no error + atomic.AddInt64(&this.migrationContext.TotalDMLEventsApplied, int64(len(dmlEvents))) + if this.migrationContext.CountTableRows { + atomic.AddInt64(&this.migrationContext.RowsDeltaEstimate, totalDelta) + } + return nil +} From 445c67635dcaa4a42467a3f9ec86556bb6484ca1 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 3 Jan 2017 14:31:19 +0200 Subject: [PATCH 057/104] batching DML writes, configurable --dml-batch-size --- go/base/context.go | 12 +++++++++ go/cmd/gh-ost/main.go | 2 ++ go/logic/applier.go | 1 + go/logic/migrator.go | 62 ++++++++++++++++++++++++++++++++++--------- 4 files changed, 65 insertions(+), 12 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 0f74f97..8775408 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -149,6 +149,7 @@ type MigrationContext struct { controlReplicasLagResult mysql.ReplicationLagResult TotalRowsCopied int64 TotalDMLEventsApplied int64 + DMLBatchSize int64 isThrottled bool throttleReason string throttleReasonHint ThrottleReasonHint @@ -208,6 +209,7 @@ func newMigrationContext() *MigrationContext { ApplierConnectionConfig: mysql.NewConnectionConfig(), MaxLagMillisecondsThrottleThreshold: 1500, CutOverLockTimeoutSeconds: 3, + DMLBatchSize: 10, maxLoad: NewLoadMap(), criticalLoad: NewLoadMap(), throttleMutex: &sync.Mutex{}, @@ -418,6 +420,16 @@ func (this *MigrationContext) SetChunkSize(chunkSize int64) { atomic.StoreInt64(&this.ChunkSize, chunkSize) } +func (this *MigrationContext) SetDMLBatchSize(batchSize int64) { + if batchSize < 1 { + batchSize = 1 + } + if batchSize > 100 { + batchSize = 100 + } + atomic.StoreInt64(&this.DMLBatchSize, batchSize) +} + func (this *MigrationContext) SetThrottleGeneralCheckResult(checkResult *ThrottleCheckResult) *ThrottleCheckResult { this.throttleMutex.Lock() defer this.throttleMutex.Unlock() diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index fdfa919..4986d1e 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -83,6 +83,7 @@ func main() { flag.BoolVar(&migrationContext.SwitchToRowBinlogFormat, "switch-to-rbr", false, "let this tool automatically switch binary log format to 'ROW' on the replica, if needed. The format will NOT be switched back. I'm too scared to do that, and wish to protect you if you happen to execute another migration while this one is running") flag.BoolVar(&migrationContext.AssumeRBR, "assume-rbr", false, "set to 'true' when you know for certain your server uses 'ROW' binlog_format. gh-ost is unable to tell, event after reading binlog_format, whether the replication process does indeed use 'ROW', and restarts replication to be certain RBR setting is applied. Such operation requires SUPER privileges which you might not have. Setting this flag avoids restarting replication and you can proceed to use gh-ost without SUPER privileges") chunkSize := flag.Int64("chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 100-100,000)") + dmlBatchSize := flag.Int64("dml-batch-size", 10, "batch size for DML events to apply in a single transaction (range 1-100)") defaultRetries := flag.Int64("default-retries", 60, "Default number of retries for various operations before panicking") cutOverLockTimeoutSeconds := flag.Int64("cut-over-lock-timeout-seconds", 3, "Max number of seconds to hold locks on tables while attempting to cut-over (retry attempted when lock exceeds timeout)") niceRatio := flag.Float64("nice-ratio", 0, "force being 'nice', imply sleep time per chunk time; range: [0.0..100.0]. Example values: 0 is aggressive. 1: for every 1ms spent copying rows, sleep additional 1ms (effectively doubling runtime); 0.7: for every 10ms spend in a rowcopy chunk, spend 7ms sleeping immediately after") @@ -220,6 +221,7 @@ func main() { migrationContext.SetHeartbeatIntervalMilliseconds(*heartbeatIntervalMillis) migrationContext.SetNiceRatio(*niceRatio) migrationContext.SetChunkSize(*chunkSize) + migrationContext.SetDMLBatchSize(*dmlBatchSize) migrationContext.SetMaxLagMillisecondsThrottleThreshold(*maxLagMillis) migrationContext.SetReplicationLagQuery(*replicationLagQuery) migrationContext.SetThrottleQuery(*throttleQuery) diff --git a/go/logic/applier.go b/go/logic/applier.go index 8fb940a..d530d11 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -993,5 +993,6 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) if this.migrationContext.CountTableRows { atomic.AddInt64(&this.migrationContext.RowsDeltaEstimate, totalDelta) } + log.Debugf("ApplyDMLEventQueries() applied %d events in one transaction", len(dmlEvents)) return nil } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 423ed27..5db089a 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -37,6 +37,21 @@ func ReadChangelogState(s string) ChangelogState { type tableWriteFunc func() error +type applyEventStruct struct { + writeFunc *tableWriteFunc + dmlEvent *binlog.BinlogDMLEvent +} + +func newApplyEventStructByFunc(writeFunc *tableWriteFunc) *applyEventStruct { + result := &applyEventStruct{writeFunc: writeFunc} + return result +} + +func newApplyEventStructByDML(dmlEvent *binlog.BinlogDMLEvent) *applyEventStruct { + result := &applyEventStruct{dmlEvent: dmlEvent} + return result +} + const ( applyEventsQueueBuffer = 100 ) @@ -71,7 +86,7 @@ type Migrator struct { // copyRowsQueue should not be buffered; if buffered some non-damaging but // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete copyRowsQueue chan tableWriteFunc - applyEventsQueue chan tableWriteFunc + applyEventsQueue chan *applyEventStruct handledChangelogStates map[string]bool } @@ -86,7 +101,7 @@ func NewMigrator() *Migrator { allEventsUpToLockProcessed: make(chan string), copyRowsQueue: make(chan tableWriteFunc), - applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), + applyEventsQueue: make(chan *applyEventStruct, applyEventsQueueBuffer), handledChangelogStates: make(map[string]bool), } return migrator @@ -194,7 +209,7 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er } case AllEventsUpToLockProcessed: { - applyEventFunc := func() error { + var applyEventFunc tableWriteFunc = func() error { this.allEventsUpToLockProcessed <- changelogStateString return nil } @@ -204,7 +219,7 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er // So as not to create a potential deadlock, we write this func to applyEventsQueue // asynchronously, understanding it doesn't really matter. go func() { - this.applyEventsQueue <- applyEventFunc + this.applyEventsQueue <- newApplyEventStructByFunc(&applyEventFunc) }() } default: @@ -917,11 +932,7 @@ func (this *Migrator) addDMLEventsListener() error { this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, func(dmlEvent *binlog.BinlogDMLEvent) error { - // Create a task to apply the DML event; this will be execute by executeWriteFuncs() - applyEventFunc := func() error { - return this.applier.ApplyDMLEventQuery(dmlEvent) - } - this.applyEventsQueue <- applyEventFunc + this.applyEventsQueue <- newApplyEventStructByDML(dmlEvent) return nil }, ) @@ -1032,10 +1043,37 @@ func (this *Migrator) executeWriteFuncs() error { // We give higher priority to event processing, then secondary priority to // rowcopy select { - case applyEventFunc := <-this.applyEventsQueue: + case applyEventStruct := <-this.applyEventsQueue: { - if err := this.retryOperation(applyEventFunc); err != nil { - return log.Errore(err) + if applyEventStruct.writeFunc != nil { + if err := this.retryOperation(*applyEventStruct.writeFunc); err != nil { + return log.Errore(err) + } + } + if applyEventStruct.dmlEvent != nil { + dmlEvents := [](*binlog.BinlogDMLEvent){} + dmlEvents = append(dmlEvents, applyEventStruct.dmlEvent) + + availableEvents := len(this.applyEventsQueue) + batchSize := int(atomic.LoadInt64(&this.migrationContext.DMLBatchSize)) + if availableEvents > batchSize { + availableEvents = batchSize + } + for i := 0; i < availableEvents; i++ { + additionalStruct := <-this.applyEventsQueue + if additionalStruct.dmlEvent == nil { + // Not a DML. We don't group this, and we don't batch any further + break + } + dmlEvents = append(dmlEvents, additionalStruct.dmlEvent) + } + // Create a task to apply the DML event; this will be execute by executeWriteFuncs() + var applyEventFunc tableWriteFunc = func() error { + return this.applier.ApplyDMLEventQueries(dmlEvents) + } + if err := this.retryOperation(applyEventFunc); err != nil { + return log.Errore(err) + } } } default: From 914f692af2888edacead61eca911455448c982da Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 4 Jan 2017 08:30:23 +0200 Subject: [PATCH 058/104] grammar --- doc/subsecond-lag.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/subsecond-lag.md b/doc/subsecond-lag.md index 564d337..5152978 100644 --- a/doc/subsecond-lag.md +++ b/doc/subsecond-lag.md @@ -14,6 +14,6 @@ In both cases, `gh-ost` uses an internal heartbeat mechanism. It injects heartbe You can explicitly define how frequently will `gh-ost` inject heartbeat events, via `heartbeat-interval-millis`. You should set `heartbeat-interval-millis <= max-lag-millis`. It still works if not, but loses granularity and effect. -In earlier versions, the `--throttle-control-replicas` list was subjected to `1` second resolution or to 3rd party heartbeat injections such as `pt-heartbeat`. This is no longer the case. The argument `--replication-lag-query` has been deprecated and no longer needed. +In earlier versions, the `--throttle-control-replicas` list was subjected to `1` second resolution or to 3rd party heartbeat injections such as `pt-heartbeat`. This is no longer the case. The argument `--replication-lag-query` has been deprecated and is no longer needed. Our production migrations use sub-second lag throttling and are able to keep our entire fleet of replicas well below `1sec` lag. We use `--heartbeat-interval-millis=100` on our production migrations with a `--max-lag-millis` value of between `300` and `500`. From 8d0faa55e3f30954dd42f45f4d81e4358254dbc2 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 4 Jan 2017 08:44:04 +0200 Subject: [PATCH 059/104] explicit rollback in ApplyDMLEventQueries() --- go/logic/applier.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/go/logic/applier.go b/go/logic/applier.go index d530d11..d1ad331 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -961,21 +961,27 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) if err != nil { return err } + + rollback := func(err error) error { + tx.Rollback() + return err + } + sessionQuery := `SET SESSION time_zone = '+00:00', sql_mode = CONCAT(@@session.sql_mode, ',STRICT_ALL_TABLES') ` if _, err := tx.Exec(sessionQuery); err != nil { - return err + return rollback(err) } for _, dmlEvent := range dmlEvents { query, args, rowDelta, err := this.buildDMLEventQuery(dmlEvent) if err != nil { - return err + return rollback(err) } if _, err := tx.Exec(query, args...); err != nil { err = fmt.Errorf("%s; query=%s; args=%+v", err.Error(), query, args) - return err + return rollback(err) } totalDelta += rowDelta } From 645af21d038138d421a059c804a6a5d0bdc8f679 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 4 Jan 2017 12:39:57 +0200 Subject: [PATCH 060/104] extracted onApplyEventStruct() --- go/logic/migrator.go | 67 ++++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 4478b0d..3d83d27 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -1024,6 +1024,40 @@ func (this *Migrator) iterateChunks() error { return nil } +func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { + if eventStruct.writeFunc != nil { + if err := this.retryOperation(*eventStruct.writeFunc); err != nil { + return log.Errore(err) + } + } + if eventStruct.dmlEvent != nil { + dmlEvents := [](*binlog.BinlogDMLEvent){} + dmlEvents = append(dmlEvents, eventStruct.dmlEvent) + + availableEvents := len(this.applyEventsQueue) + batchSize := int(atomic.LoadInt64(&this.migrationContext.DMLBatchSize)) + if availableEvents > batchSize { + availableEvents = batchSize + } + for i := 0; i < availableEvents; i++ { + additionalStruct := <-this.applyEventsQueue + if additionalStruct.dmlEvent == nil { + // Not a DML. We don't group this, and we don't batch any further + break + } + dmlEvents = append(dmlEvents, additionalStruct.dmlEvent) + } + // Create a task to apply the DML event; this will be execute by executeWriteFuncs() + var applyEventFunc tableWriteFunc = func() error { + return this.applier.ApplyDMLEventQueries(dmlEvents) + } + if err := this.retryOperation(applyEventFunc); err != nil { + return log.Errore(err) + } + } + return nil +} + // executeWriteFuncs writes data via applier: both the rowcopy and the events backlog. // This is where the ghost table gets the data. The function fills the data single-threaded. // Both event backlog and rowcopy events are polled; the backlog events have precedence. @@ -1038,38 +1072,9 @@ func (this *Migrator) executeWriteFuncs() error { // We give higher priority to event processing, then secondary priority to // rowcopy select { - case applyEventStruct := <-this.applyEventsQueue: + case eventStruct := <-this.applyEventsQueue: { - if applyEventStruct.writeFunc != nil { - if err := this.retryOperation(*applyEventStruct.writeFunc); err != nil { - return log.Errore(err) - } - } - if applyEventStruct.dmlEvent != nil { - dmlEvents := [](*binlog.BinlogDMLEvent){} - dmlEvents = append(dmlEvents, applyEventStruct.dmlEvent) - - availableEvents := len(this.applyEventsQueue) - batchSize := int(atomic.LoadInt64(&this.migrationContext.DMLBatchSize)) - if availableEvents > batchSize { - availableEvents = batchSize - } - for i := 0; i < availableEvents; i++ { - additionalStruct := <-this.applyEventsQueue - if additionalStruct.dmlEvent == nil { - // Not a DML. We don't group this, and we don't batch any further - break - } - dmlEvents = append(dmlEvents, additionalStruct.dmlEvent) - } - // Create a task to apply the DML event; this will be execute by executeWriteFuncs() - var applyEventFunc tableWriteFunc = func() error { - return this.applier.ApplyDMLEventQueries(dmlEvents) - } - if err := this.retryOperation(applyEventFunc); err != nil { - return log.Errore(err) - } - } + this.onApplyEventStruct(eventStruct) } default: { From baaa255182fa22452b9a6a33451f3913f6dc367c Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 4 Jan 2017 12:42:21 +0200 Subject: [PATCH 061/104] bailing out from onApplyEventStruct() --- go/logic/migrator.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 3d83d27..7b3daf3 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -1074,7 +1074,9 @@ func (this *Migrator) executeWriteFuncs() error { select { case eventStruct := <-this.applyEventsQueue: { - this.onApplyEventStruct(eventStruct) + if err := this.onApplyEventStruct(eventStruct); err != nil { + return err + } } default: { From f7d2beb4d28b4aa1d7b0acaf4c9adf0cd6166363 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 5 Jan 2017 08:13:51 +0200 Subject: [PATCH 062/104] handling a non-DML event at the end of a dml-event sequence --- go/logic/migrator.go | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 7b3daf3..61db6db 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -1025,14 +1025,21 @@ func (this *Migrator) iterateChunks() error { } func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { - if eventStruct.writeFunc != nil { - if err := this.retryOperation(*eventStruct.writeFunc); err != nil { - return log.Errore(err) + handleNonDMLEventStruct := func(eventStruct *applyEventStruct) error { + if eventStruct.writeFunc != nil { + if err := this.retryOperation(*eventStruct.writeFunc); err != nil { + return log.Errore(err) + } } + return nil + } + if eventStruct.dmlEvent == nil { + return handleNonDMLEventStruct(eventStruct) } if eventStruct.dmlEvent != nil { dmlEvents := [](*binlog.BinlogDMLEvent){} dmlEvents = append(dmlEvents, eventStruct.dmlEvent) + var nonDmlStructToApply *applyEventStruct availableEvents := len(this.applyEventsQueue) batchSize := int(atomic.LoadInt64(&this.migrationContext.DMLBatchSize)) @@ -1043,6 +1050,7 @@ func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { additionalStruct := <-this.applyEventsQueue if additionalStruct.dmlEvent == nil { // Not a DML. We don't group this, and we don't batch any further + nonDmlStructToApply = additionalStruct break } dmlEvents = append(dmlEvents, additionalStruct.dmlEvent) @@ -1054,6 +1062,13 @@ func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { if err := this.retryOperation(applyEventFunc); err != nil { return log.Errore(err) } + if nonDmlStructToApply != nil { + // We pulled DML events from the queue, and then we hit a non-DML event. Wait! + // We need to handle it! + if err := handleNonDMLEventStruct(nonDmlStructToApply); err != nil { + return log.Errore(err) + } + } } return nil } From 115702716146c07e824c99d5d77360a02c558d11 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 8 Jan 2017 09:46:12 +0200 Subject: [PATCH 063/104] documenting --dml-batch-size --- RELEASE_VERSION | 2 +- doc/command-line-flags.md | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/RELEASE_VERSION b/RELEASE_VERSION index 15245f3..ffcbe71 100644 --- a/RELEASE_VERSION +++ b/RELEASE_VERSION @@ -1 +1 @@ -1.0.32 +1.0.34 diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index d7bc97b..83bfe05 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -65,6 +65,17 @@ At this time (10-2016) `gh-ost` does not support foreign keys on migrated tables See also: [`skip-foreign-key-checks`](#skip-foreign-key-checks) + +### dml-batch-size + +`gh-ost` reads event from the binary log and applies them onto the _ghost_ table. It does so in batched writes: grouping multiple events to apply in a single transaction. This gives better write throughput as we don't need to sync the transaction log to disk for each event. + +The `--dml-batch-size` flag controls the size of the batched write. Allowed values are `1 - 100`, where `1` means no batching (every event from the binary log is applied onto the _ghost_ table on its own transaction). Default value is `10`. + +Why is this behavior configurable? Different workloads have different characteristics. Some workloads have very large writes, such that aggregating even `50` writes into a transaction makes for a significant transaction size. On other workloads write rate is high such that one just can't allow for a hundred more syncs to disk per second. The default value of `10` is a modest compromise that should probably work very well for most workloads. Your mileage may vary. + +Noteworthy is that setting `--dml-batch-size` to higher value _does not_ mean `gh-ost` blocks or waits on writes. The batch size is an upper limit on transaction size, not a minimal one. If `gh-ost` doesn't have "enough" events in the pipe, it does not wait on the binary log, it just writes what it already has. This conveniently suggests that if write load is light enough for `gh-ost` to only see a few events in the binary log at a given time, then it is also light neough for `gh-ost` to apply a fraction of the batch size. + ### exact-rowcount A `gh-ost` execution need to copy whatever rows you have in your existing table onto the ghost table. This can, and often be, a large number. Exactly what that number is? From 5ba98a627c2ae6d9b3ddc0bcc619cdb1d1241da8 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 9 Jan 2017 15:40:29 +0200 Subject: [PATCH 064/104] localtests for unsigned-int --- localtests/unsigned-modify/create.sql | 26 ++++++++++++++++++++++++++ localtests/unsigned-modify/extra_args | 1 + 2 files changed, 27 insertions(+) create mode 100644 localtests/unsigned-modify/create.sql create mode 100644 localtests/unsigned-modify/extra_args diff --git a/localtests/unsigned-modify/create.sql b/localtests/unsigned-modify/create.sql new file mode 100644 index 0000000..9adbfd2 --- /dev/null +++ b/localtests/unsigned-modify/create.sql @@ -0,0 +1,26 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id bigint(20) NOT NULL AUTO_INCREMENT, + column1 int(11) NOT NULL, + column2 smallint(5) unsigned NOT NULL, + column3 mediumint(8) unsigned NOT NULL, + column4 tinyint(3) unsigned NOT NULL, + column6 int(11) NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY c12_uix (column1, column2) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + -- mediumint maxvalue: 16777215 (unsigned), 8388607 (signed) + insert into gh_ost_test values (NULL, 13382498, 536, 8388607, 3, 1483892218); + insert into gh_ost_test values (NULL, 13382498, 536, 10000000, 3, 1483892218); +end ;; diff --git a/localtests/unsigned-modify/extra_args b/localtests/unsigned-modify/extra_args new file mode 100644 index 0000000..bfa1c4d --- /dev/null +++ b/localtests/unsigned-modify/extra_args @@ -0,0 +1 @@ +--alter="ADD UNIQUE INDEX (id), DROP INDEX c12_uix, DROP PRIMARY KEY,ADD PRIMARY KEY(column1,column2),MODIFY column2 SMALLINT UNSIGNED NOT NULL, MODIFY column3 INT NOT NULL" From 1921b0f9dfd45bfd71b5044c2eefd3a3800fb9bb Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 09:13:05 +0200 Subject: [PATCH 065/104] complicating the test --- localtests/unsigned-modify/create.sql | 1 + localtests/unsigned-modify/extra_args | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/localtests/unsigned-modify/create.sql b/localtests/unsigned-modify/create.sql index 9adbfd2..fd27c9f 100644 --- a/localtests/unsigned-modify/create.sql +++ b/localtests/unsigned-modify/create.sql @@ -5,6 +5,7 @@ create table gh_ost_test ( column2 smallint(5) unsigned NOT NULL, column3 mediumint(8) unsigned NOT NULL, column4 tinyint(3) unsigned NOT NULL, + column5 int(11) NOT NULL, column6 int(11) NOT NULL, PRIMARY KEY (id), UNIQUE KEY c12_uix (column1, column2) diff --git a/localtests/unsigned-modify/extra_args b/localtests/unsigned-modify/extra_args index bfa1c4d..6ae5eaf 100644 --- a/localtests/unsigned-modify/extra_args +++ b/localtests/unsigned-modify/extra_args @@ -1 +1 @@ ---alter="ADD UNIQUE INDEX (id), DROP INDEX c12_uix, DROP PRIMARY KEY,ADD PRIMARY KEY(column1,column2),MODIFY column2 SMALLINT UNSIGNED NOT NULL, MODIFY column3 INT NOT NULL" +--alter="ADD UNIQUE INDEX (id), DROP COLUMN column5, DROP INDEX c12_uix, DROP PRIMARY KEY,ADD PRIMARY KEY(column1,column2),MODIFY column2 SMALLINT UNSIGNED NOT NULL, MODIFY column3 INT NOT NULL" From 374e82c8eb12e32d14c5e96cf60f3727b6566762 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 09:17:49 +0200 Subject: [PATCH 066/104] having the test reproduce the problem for which it was created --- localtests/unsigned-modify/create.sql | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/localtests/unsigned-modify/create.sql b/localtests/unsigned-modify/create.sql index fd27c9f..9a0eb18 100644 --- a/localtests/unsigned-modify/create.sql +++ b/localtests/unsigned-modify/create.sql @@ -8,7 +8,7 @@ create table gh_ost_test ( column5 int(11) NOT NULL, column6 int(11) NOT NULL, PRIMARY KEY (id), - UNIQUE KEY c12_uix (column1, column2) + KEY c12_uix (column1, column2) ) auto_increment=1; drop event if exists gh_ost_test; @@ -22,6 +22,6 @@ create event gh_ost_test do begin -- mediumint maxvalue: 16777215 (unsigned), 8388607 (signed) - insert into gh_ost_test values (NULL, 13382498, 536, 8388607, 3, 1483892218); - insert into gh_ost_test values (NULL, 13382498, 536, 10000000, 3, 1483892218); + insert into gh_ost_test values (NULL, 13382498, 536, 8388607, 3, 1483892217, 1483892218); + insert into gh_ost_test values (NULL, 13382498, 536, 10000000, 3, 1483892217, 1483892218); end ;; From 67cfba3ce45681ce62f5a756a78f48c25a58df5d Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 09:24:57 +0200 Subject: [PATCH 067/104] simplified test --- localtests/unsigned-modify/extra_args | 1 - 1 file changed, 1 deletion(-) delete mode 100644 localtests/unsigned-modify/extra_args diff --git a/localtests/unsigned-modify/extra_args b/localtests/unsigned-modify/extra_args deleted file mode 100644 index 6ae5eaf..0000000 --- a/localtests/unsigned-modify/extra_args +++ /dev/null @@ -1 +0,0 @@ ---alter="ADD UNIQUE INDEX (id), DROP COLUMN column5, DROP INDEX c12_uix, DROP PRIMARY KEY,ADD PRIMARY KEY(column1,column2),MODIFY column2 SMALLINT UNSIGNED NOT NULL, MODIFY column3 INT NOT NULL" From a07a6f8cf5f48bdb97c622605a87a4389cbc80c6 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 09:57:42 +0200 Subject: [PATCH 068/104] fixing mediumint unsigned problem --- go/logic/inspect.go | 7 ++++++- go/sql/builder.go | 4 ++-- go/sql/types.go | 13 ++++++++++++- localtests/unsigned-modify/create.sql | 3 ++- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 6c81e99..9daef39 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -529,6 +529,11 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL columnsList.SetUnsigned(columnName) } } + if strings.Contains(columnType, "mediumint") { + for _, columnsList := range columnsLists { + columnsList.GetColumn(columnName).Type = sql.MediumIntColumnType + } + } if strings.Contains(columnType, "timestamp") { for _, columnsList := range columnsLists { columnsList.GetColumn(columnName).Type = sql.TimestampColumnType @@ -541,7 +546,7 @@ func (this *Inspector) applyColumnTypes(databaseName, tableName string, columnsL } if strings.HasPrefix(columnType, "enum") { for _, columnsList := range columnsLists { - columnsList.GetColumn(columnName).Type = sql.EnumColumnValue + columnsList.GetColumn(columnName).Type = sql.EnumColumnType } } if charset := m.GetString("CHARACTER_SET_NAME"); charset != "" { diff --git a/go/sql/builder.go b/go/sql/builder.go index c455992..0b1ac9e 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -258,7 +258,7 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) for i, column := range uniqueKeyColumns.Columns() { uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i]) - if column.Type == EnumColumnValue { + if column.Type == EnumColumnType { uniqueKeyColumnAscending[i] = fmt.Sprintf("concat(%s) asc", uniqueKeyColumnNames[i]) uniqueKeyColumnDescending[i] = fmt.Sprintf("concat(%s) desc", uniqueKeyColumnNames[i]) } else { @@ -309,7 +309,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumnNames), len(uniqueKeyColumnNames)) for i, column := range uniqueKeyColumns.Columns() { uniqueKeyColumnNames[i] = EscapeName(uniqueKeyColumnNames[i]) - if column.Type == EnumColumnValue { + if column.Type == EnumColumnType { uniqueKeyColumnOrder[i] = fmt.Sprintf("concat(%s) %s", uniqueKeyColumnNames[i], order) } else { uniqueKeyColumnOrder[i] = fmt.Sprintf("%s %s", uniqueKeyColumnNames[i], order) diff --git a/go/sql/types.go b/go/sql/types.go index 720f92f..9f4f8e7 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -18,9 +18,12 @@ const ( UnknownColumnType ColumnType = iota TimestampColumnType = iota DateTimeColumnType = iota - EnumColumnValue = iota + EnumColumnType = iota + MediumIntColumnType = iota ) +const maxMediumintUnsigned int32 = 16777215 + type TimezoneConvertion struct { ToTimezone string } @@ -50,6 +53,14 @@ func (this *Column) convertArg(arg interface{}) interface{} { return uint16(i) } if i, ok := arg.(int32); ok { + if this.Type == MediumIntColumnType { + // problem with mediumint is that it's a 3-byte type. There is no compatible golang type to match that. + // So to convert from negative to positive we'd need to convert the value manually + if i >= 0 { + return i + } + return uint32(maxMediumintUnsigned + i + 1) + } return uint32(i) } if i, ok := arg.(int64); ok { diff --git a/localtests/unsigned-modify/create.sql b/localtests/unsigned-modify/create.sql index 9a0eb18..85e7580 100644 --- a/localtests/unsigned-modify/create.sql +++ b/localtests/unsigned-modify/create.sql @@ -8,7 +8,7 @@ create table gh_ost_test ( column5 int(11) NOT NULL, column6 int(11) NOT NULL, PRIMARY KEY (id), - KEY c12_uix (column1, column2) + KEY c12_ix (column1, column2) ) auto_increment=1; drop event if exists gh_ost_test; @@ -23,5 +23,6 @@ create event gh_ost_test begin -- mediumint maxvalue: 16777215 (unsigned), 8388607 (signed) insert into gh_ost_test values (NULL, 13382498, 536, 8388607, 3, 1483892217, 1483892218); + insert into gh_ost_test values (NULL, 13382498, 536, 8388607, 250, 1483892217, 1483892218); insert into gh_ost_test values (NULL, 13382498, 536, 10000000, 3, 1483892217, 1483892218); end ;; From 115e325d7cb9d6f2eb1ac41bca82b5d0c42f6766 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 11:22:00 +0200 Subject: [PATCH 069/104] sanity replication restart before each test --- RELEASE_VERSION | 2 +- localtests/test.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/RELEASE_VERSION b/RELEASE_VERSION index ffcbe71..28dff43 100644 --- a/RELEASE_VERSION +++ b/RELEASE_VERSION @@ -1 +1 @@ -1.0.34 +1.0.35 diff --git a/localtests/test.sh b/localtests/test.sh index ab7b459..eb0274a 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -49,7 +49,7 @@ test_single() { echo -n "Testing: $test_name" echo_dot - gh-ost-test-mysql-replica -e "start slave" + gh-ost-test-mysql-replica -e "stop slave; start slave; do sleep(1)" echo_dot gh-ost-test-mysql-master --default-character-set=utf8mb4 test < $tests_path/$test_name/create.sql From 36a66e0c05c8bb2cd27efccfce5ab6ad66be0f7a Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 12:35:10 +0200 Subject: [PATCH 070/104] supporting customized 'order by' in tests --- localtests/fail-drop-pk/create.sql | 22 ++++++++++++++++++++++ localtests/fail-drop-pk/expect_failure | 1 + localtests/fail-drop-pk/extra_args | 1 + localtests/test.sh | 10 +++++++--- 4 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 localtests/fail-drop-pk/create.sql create mode 100644 localtests/fail-drop-pk/expect_failure create mode 100644 localtests/fail-drop-pk/extra_args diff --git a/localtests/fail-drop-pk/create.sql b/localtests/fail-drop-pk/create.sql new file mode 100644 index 0000000..5bb45f2 --- /dev/null +++ b/localtests/fail-drop-pk/create.sql @@ -0,0 +1,22 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + ts timestamp, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11, now()); + insert into gh_ost_test values (null, 13, now()); + insert into gh_ost_test values (null, 17, now()); +end ;; diff --git a/localtests/fail-drop-pk/expect_failure b/localtests/fail-drop-pk/expect_failure new file mode 100644 index 0000000..021ae87 --- /dev/null +++ b/localtests/fail-drop-pk/expect_failure @@ -0,0 +1 @@ +No PRIMARY nor UNIQUE key found in table diff --git a/localtests/fail-drop-pk/extra_args b/localtests/fail-drop-pk/extra_args new file mode 100644 index 0000000..2017cbc --- /dev/null +++ b/localtests/fail-drop-pk/extra_args @@ -0,0 +1 @@ +--alter="change id id int, drop primary key" diff --git a/localtests/test.sh b/localtests/test.sh index ab7b459..c0b97a5 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -49,7 +49,7 @@ test_single() { echo -n "Testing: $test_name" echo_dot - gh-ost-test-mysql-replica -e "start slave" + gh-ost-test-mysql-replica -e "stop slave; start slave; do sleep(1)" echo_dot gh-ost-test-mysql-master --default-character-set=utf8mb4 test < $tests_path/$test_name/create.sql @@ -59,12 +59,16 @@ test_single() { fi orig_columns="*" ghost_columns="*" + order_by="" if [ -f $tests_path/$test_name/orig_columns ] ; then orig_columns=$(cat $tests_path/$test_name/orig_columns) fi if [ -f $tests_path/$test_name/ghost_columns ] ; then ghost_columns=$(cat $tests_path/$test_name/ghost_columns) fi + if [ -f $tests_path/$test_name/order_by ] ; then + order_by="order by $(cat $tests_path/$test_name/order_by)" + fi # graceful sleep for replica to catch up echo_dot sleep 1 @@ -129,8 +133,8 @@ test_single() { fi echo_dot - orig_checksum=$(gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "select ${orig_columns} from gh_ost_test" -ss | md5sum) - ghost_checksum=$(gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "select ${ghost_columns} from _gh_ost_test_gho" -ss | md5sum) + orig_checksum=$(gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "select ${orig_columns} from gh_ost_test ${order_by}" -ss | md5sum) + ghost_checksum=$(gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "select ${ghost_columns} from _gh_ost_test_gho ${order_by}" -ss | md5sum) if [ "$orig_checksum" != "$ghost_checksum" ] ; then echo "ERROR $test_name: checksum mismatch" From dd9a3e1d0c7286583e7add5a0d76b16e979325fb Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 12:35:42 +0200 Subject: [PATCH 071/104] adding PK, UK, PK-to-UK conversion tests --- localtests/fail-no-shared-uk/create.sql | 22 +++++++++++++++++++ localtests/fail-no-shared-uk/expect_failure | 1 + localtests/fail-no-shared-uk/extra_args | 1 + localtests/swap-pk-uk/create.sql | 24 +++++++++++++++++++++ localtests/swap-pk-uk/extra_args | 1 + localtests/swap-pk-uk/order_by | 1 + localtests/swap-uk/create.sql | 22 +++++++++++++++++++ localtests/swap-uk/extra_args | 1 + 8 files changed, 73 insertions(+) create mode 100644 localtests/fail-no-shared-uk/create.sql create mode 100644 localtests/fail-no-shared-uk/expect_failure create mode 100644 localtests/fail-no-shared-uk/extra_args create mode 100644 localtests/swap-pk-uk/create.sql create mode 100644 localtests/swap-pk-uk/extra_args create mode 100644 localtests/swap-pk-uk/order_by create mode 100644 localtests/swap-uk/create.sql create mode 100644 localtests/swap-uk/extra_args diff --git a/localtests/fail-no-shared-uk/create.sql b/localtests/fail-no-shared-uk/create.sql new file mode 100644 index 0000000..5bb45f2 --- /dev/null +++ b/localtests/fail-no-shared-uk/create.sql @@ -0,0 +1,22 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + ts timestamp, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11, now()); + insert into gh_ost_test values (null, 13, now()); + insert into gh_ost_test values (null, 17, now()); +end ;; diff --git a/localtests/fail-no-shared-uk/expect_failure b/localtests/fail-no-shared-uk/expect_failure new file mode 100644 index 0000000..d3ef0f1 --- /dev/null +++ b/localtests/fail-no-shared-uk/expect_failure @@ -0,0 +1 @@ +No shared unique key can be found after ALTER diff --git a/localtests/fail-no-shared-uk/extra_args b/localtests/fail-no-shared-uk/extra_args new file mode 100644 index 0000000..379a77d --- /dev/null +++ b/localtests/fail-no-shared-uk/extra_args @@ -0,0 +1 @@ +--alter="drop primary key, add primary key (id, i)" diff --git a/localtests/swap-pk-uk/create.sql b/localtests/swap-pk-uk/create.sql new file mode 100644 index 0000000..aa712b8 --- /dev/null +++ b/localtests/swap-pk-uk/create.sql @@ -0,0 +1,24 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id bigint, + i int not null, + ts timestamp(6), + primary key(id), + unique key its_uidx(i, ts) +) ; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values ((unix_timestamp() << 2) + 0, 11, now(6)); + insert into gh_ost_test values ((unix_timestamp() << 2) + 1, 13, now(6)); + insert into gh_ost_test values ((unix_timestamp() << 2) + 2, 17, now(6)); + insert into gh_ost_test values ((unix_timestamp() << 2) + 3, 19, now(6)); +end ;; diff --git a/localtests/swap-pk-uk/extra_args b/localtests/swap-pk-uk/extra_args new file mode 100644 index 0000000..f4de64a --- /dev/null +++ b/localtests/swap-pk-uk/extra_args @@ -0,0 +1 @@ +--alter="drop primary key, drop key its_uidx, add primary key (i, ts), add unique key id_uidx(id)" diff --git a/localtests/swap-pk-uk/order_by b/localtests/swap-pk-uk/order_by new file mode 100644 index 0000000..074d1ee --- /dev/null +++ b/localtests/swap-pk-uk/order_by @@ -0,0 +1 @@ +id diff --git a/localtests/swap-uk/create.sql b/localtests/swap-uk/create.sql new file mode 100644 index 0000000..5bb45f2 --- /dev/null +++ b/localtests/swap-uk/create.sql @@ -0,0 +1,22 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + i int not null, + ts timestamp, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 11, now()); + insert into gh_ost_test values (null, 13, now()); + insert into gh_ost_test values (null, 17, now()); +end ;; diff --git a/localtests/swap-uk/extra_args b/localtests/swap-uk/extra_args new file mode 100644 index 0000000..ad80561 --- /dev/null +++ b/localtests/swap-uk/extra_args @@ -0,0 +1 @@ +--alter="drop primary key, add unique key(id)" From fae95fc3e126ba22cd0ddce7565c0ed196bc8aaa Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 12:43:06 +0200 Subject: [PATCH 072/104] Added test for swapping uniqye keys without pk --- localtests/swap-uk-uk/create.sql | 24 ++++++++++++++++++++++++ localtests/swap-uk-uk/extra_args | 1 + localtests/swap-uk-uk/order_by | 1 + 3 files changed, 26 insertions(+) create mode 100644 localtests/swap-uk-uk/create.sql create mode 100644 localtests/swap-uk-uk/extra_args create mode 100644 localtests/swap-uk-uk/order_by diff --git a/localtests/swap-uk-uk/create.sql b/localtests/swap-uk-uk/create.sql new file mode 100644 index 0000000..5fcbf32 --- /dev/null +++ b/localtests/swap-uk-uk/create.sql @@ -0,0 +1,24 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id bigint, + i int not null, + ts timestamp(6), + unique key id_uidx(id), + unique key its_uidx(i, ts) +) ; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values ((unix_timestamp() << 2) + 0, 11, now(6)); + insert into gh_ost_test values ((unix_timestamp() << 2) + 1, 13, now(6)); + insert into gh_ost_test values ((unix_timestamp() << 2) + 2, 17, now(6)); + insert into gh_ost_test values ((unix_timestamp() << 2) + 3, 19, now(6)); +end ;; diff --git a/localtests/swap-uk-uk/extra_args b/localtests/swap-uk-uk/extra_args new file mode 100644 index 0000000..84161a5 --- /dev/null +++ b/localtests/swap-uk-uk/extra_args @@ -0,0 +1 @@ +--alter="drop key id_uidx, drop key its_uidx, add unique key its2_uidx(i, ts), add unique key id2_uidx(id)" diff --git a/localtests/swap-uk-uk/order_by b/localtests/swap-uk-uk/order_by new file mode 100644 index 0000000..074d1ee --- /dev/null +++ b/localtests/swap-uk-uk/order_by @@ -0,0 +1 @@ +id From c5d1ef7f55cd0763a77961a26f6e32ca805b860a Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 10 Jan 2017 15:05:06 +0200 Subject: [PATCH 073/104] proper documentation for shared key requirements --- doc/requirements-and-limitations.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/requirements-and-limitations.md b/doc/requirements-and-limitations.md index 81e7226..233f7b6 100644 --- a/doc/requirements-and-limitations.md +++ b/doc/requirements-and-limitations.md @@ -31,7 +31,9 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - MySQL 5.7 `JSON` columns are not supported. They are likely to be supported shortly. - The two _before_ & _after_ tables must share some `UNIQUE KEY`. Such key would be used by `gh-ost` to iterate the table. - - As an example, if your table has a single `UNIQUE KEY` and no `PRIMARY KEY`, and you wish to replace it with a `PRIMARY KEY`, you will need two migrations: one to add the `PRIMARY KEY` (this migration will use the existing `UNIQUE KEY`), another to drop the now redundant `UNIQUE KEY` (this migration will use the `PRIMARY KEY`). + - What matters is not the key's _name_, but the columns covered by such key. + - As an example, if your table has a single `UNIQUE KEY my_unique_key(list, of, columns)` and no `PRIMARY KEY`, and you wish to replace the unique key with a `PRIMARY KEY`, you are good to go with a single migration: `--alter='DROP KEY my_unique_key, ADD PRIMARY KEY (list, of, columns)'` + - It is OK if the migrations changes type of columns within the shared key. For example, assume `id INT PRIMARY KEY`, it is OK to `--alter='MODIFY id BIGINT'`. - The chosen migration key must not include columns with `NULL` values. - `gh-ost` will do its best to pick a migration key with non-nullable columns. It will by default refuse a migration where the only possible `UNIQUE KEY` includes nullable-columns. You may override this refusal via `--allow-nullable-unique-key` but **you must** be sure there are no actual `NULL` values in those columns. Such `NULL` values would cause a data integrity problem and potentially a corrupted migration. @@ -48,4 +50,4 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - If you have en `enum` field as part of your migration key (typically the `PRIMARY KEY`), migration performance will be degraded and potentially bad. [Read more](https://github.com/github/gh-ost/pull/277#issuecomment-254811520) -- Migrating a `FEDERATED` table is unsupported and is irrelevant to the problem `gh-ost` tackles. +- Migrating a `FEDERATED` table is unsupported and is irrelevant to the problem `gh-ost` tackles. From 0d31ca8a389926618f3f78d52935c7041d2c5231 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 15 Jan 2017 09:07:55 +0200 Subject: [PATCH 074/104] Added 'shared-key' documentation --- doc/requirements-and-limitations.md | 5 +-- doc/shared-key.md | 68 +++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 doc/shared-key.md diff --git a/doc/requirements-and-limitations.md b/doc/requirements-and-limitations.md index 233f7b6..1b9fce4 100644 --- a/doc/requirements-and-limitations.md +++ b/doc/requirements-and-limitations.md @@ -30,10 +30,7 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - MySQL 5.7 `JSON` columns are not supported. They are likely to be supported shortly. -- The two _before_ & _after_ tables must share some `UNIQUE KEY`. Such key would be used by `gh-ost` to iterate the table. - - What matters is not the key's _name_, but the columns covered by such key. - - As an example, if your table has a single `UNIQUE KEY my_unique_key(list, of, columns)` and no `PRIMARY KEY`, and you wish to replace the unique key with a `PRIMARY KEY`, you are good to go with a single migration: `--alter='DROP KEY my_unique_key, ADD PRIMARY KEY (list, of, columns)'` - - It is OK if the migrations changes type of columns within the shared key. For example, assume `id INT PRIMARY KEY`, it is OK to `--alter='MODIFY id BIGINT'`. +- The two _before_ & _after_ tables must share some `UNIQUE KEY`. Such key would be used by `gh-ost` to iterate the table. See [Read more](shared-key.md) - The chosen migration key must not include columns with `NULL` values. - `gh-ost` will do its best to pick a migration key with non-nullable columns. It will by default refuse a migration where the only possible `UNIQUE KEY` includes nullable-columns. You may override this refusal via `--allow-nullable-unique-key` but **you must** be sure there are no actual `NULL` values in those columns. Such `NULL` values would cause a data integrity problem and potentially a corrupted migration. diff --git a/doc/shared-key.md b/doc/shared-key.md new file mode 100644 index 0000000..480ad97 --- /dev/null +++ b/doc/shared-key.md @@ -0,0 +1,68 @@ +# Shared key + +A requirement for a migration to run is that the two _before_ and _after_ tables have a shared unique key. This is to elaborate and illustrate on the matter. + +### Introduction + +Consider a classic, simple migration. The table is any normal: + +``` +CREATE TABLE tbl ( + id bigint unsigned not null auto_increment, + data varchar(255), + more_data int, + PRIMARY KEY(id) +) +``` + +And the migration is a simple `add column ts timestamp`. + +In such migration there is no change in indexes, and in particular no change to any unique key, and specifically no change to the `PRIMARY KEY`. To run this migration, `gh-ost` would iterate the `tbl` table using the primary key, copy rows from `tbl` to the _ghost_ table `_tbl_gho` by order of `id`, and then apply binlog events onto `_tbl_gho`. + +Applying the binlog events assumes the existence of a shared unique key. For example, an `UPDATE` statement in the binary log translate to a `REPLACE` statement which `gh-ost` applies to the _ghost_ table. Such statement expects to add or replace an existing row based on given row data. In particular, it would _replace_ an existing row if a unique key violation is met. + +So `gh-ost` correlates `tbl` and `_tbl_gho` rows using a unique key. In the above example that would be the `PRIMARY KEY`. + +### Rules + +There must be a shared set of not-null columns for which there is a unique constraint in both the original table and the migration (_ghost_) table. + +### Interpreting the rules + +The same columns must be covered by a unique key in both tables. This doesn't have to be the `PRIMARY KEY`. This doesn't have to be a key of the same name. + +Upon migration, `gh-ost` inspects both the original and _ghost_ table and attempts to find at least one such unique key (or rather, a set of columns) that is shared between the two. Typically this would just be the `PRIMARY KEY`, but sometimes you may change the `PRIMARY KEY` itself, in which case `gh-ost` will look for other options. + +`gh-ost` expects unique keys where no `NULL` values are found, i.e. all columns covered by the unique key are defined as `NOT NULL`. This is implicitly true for `PRIMARY KEY`s. If no such key can be found, `gh-ost` bails out. In the event there is no such key, but you happen to _know_ your columns have no `NULL` values even though they're `NULL`-able, you may take responsibility and pass the `--allow-nullable-unique-key`. The migration will run well as long as no `NULL` values are found in the unique key's columns. Any actual `NULL`s may corrupt the migration. + +### Examples: allowed and not allowed + +``` +create table some_table ( + id int auto_increment, + ts timestamp, + name varchar(128) not null, + owner_id int not null, + loc_id int, + primary key(id), + unique key name_uidx(name) +) +``` + +Following are examples of migrations that are _good to run_: + +- `add column i int` +- `add key owner_idx(owner_id)` +- `add unique key owner_name_idx(owner_id, name)` - though you need to make sure to not write conflicting rows while this migration runs +- `drop key name_uidx` - `primary key` is shared between the tables +- `drop primary key, add primary key(owner_id, loc_id)` - `name_uidx` is shared between the tables and is used for migration +- `change id bigint unsigned` - the `'primary key` is used. The change of type still makes the `primary key` workable. +- `drop primary key, drop key name_uidx, create primary key(name), create unique key id_uidx(id)` - swapping the two keys. `gh-ost` is still happy because `id` is still unique in both tables. So is `name`. + + +Following are examples of migrations that _cannot run_: + +- `drop primary key, drop key name_uidx` - no unique key to _ghost_ table, so clearly cannot run +- `drop primary key, drop key name_uidx, create primary key(name, owner_id)` - no shared columns to both tables. Even though `name` exists in the _ghost_ table's `primary key`, it is only part of the key and in itself does not guarantee uniqueness in the _ghost_ table. + +Also, you cannot run a migration on a table that doesn't have some form of `unique key` in the first place, such as `some_table (id int, ts timestamp)` From e020b9c756d8651b518b472b1b6d32e72aed0162 Mon Sep 17 00:00:00 2001 From: Gillian Gunson Date: Thu, 19 Jan 2017 00:17:43 -0800 Subject: [PATCH 075/104] Moved my comment suggestions here --- doc/requirements-and-limitations.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/doc/requirements-and-limitations.md b/doc/requirements-and-limitations.md index 1b9fce4..7d246a8 100644 --- a/doc/requirements-and-limitations.md +++ b/doc/requirements-and-limitations.md @@ -30,10 +30,12 @@ The `SUPER` privilege is required for `STOP SLAVE`, `START SLAVE` operations. Th - MySQL 5.7 `JSON` columns are not supported. They are likely to be supported shortly. -- The two _before_ & _after_ tables must share some `UNIQUE KEY`. Such key would be used by `gh-ost` to iterate the table. See [Read more](shared-key.md) - -- The chosen migration key must not include columns with `NULL` values. - - `gh-ost` will do its best to pick a migration key with non-nullable columns. It will by default refuse a migration where the only possible `UNIQUE KEY` includes nullable-columns. You may override this refusal via `--allow-nullable-unique-key` but **you must** be sure there are no actual `NULL` values in those columns. Such `NULL` values would cause a data integrity problem and potentially a corrupted migration. +- The two _before_ & _after_ tables must share a `PRIMARY KEY` or other `UNIQUE KEY`. This key will be used by `gh-ost` to iterate through the table rows when copying. [Read more](shared-key.md) + - The migration key must not include columns with NULL values. This means either: + 1. The columns are `NOT NULL`, or + 2. The columns are nullable but don't contain any NULL values. + - by default, `gh-ost` will not run if the only `UNIQUE KEY` includes nullable columns. + - You may override this via `--allow-nullable-unique-key` but make sure there are no actual `NULL` values in those columns. Existing NULL values can't guarantee data integrity on the migrated table. - It is not allowed to migrate a table where another table exists with same name and different upper/lower case. - For example, you may not migrate `MyTable` if another table called `MYtable` exists in the same schema. From 779e9fdd8302dc4f2e8ae94b5bc31d9522002a8b Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 29 Jan 2017 09:25:29 +0200 Subject: [PATCH 076/104] question (?) as argument in interactive commands --- go/logic/server.go | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/go/logic/server.go b/go/logic/server.go index f616f60..b1246c2 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -126,7 +126,7 @@ func (this *Server) applyServerCommand(command string, writer *bufio.Writer) (pr if len(tokens) > 1 { arg = strings.TrimSpace(tokens[1]) } - + argIsQuestion := (arg == "?") throttleHint := "# Note: you may only throttle for as long as your binary logs are not purged\n" if err := this.hooksExecutor.onInteractiveCommand(command); err != nil { @@ -152,6 +152,7 @@ no-throttle # End forced throttling (other throttling m unpostpone # Bail out a cut-over postpone; proceed to cut-over panic # panic and quit without cleanup help # This message +- use '?' (question mark) as argument to get info rather than set. e.g. "max-load=?" will just print out current max-load. `) } case "sup": @@ -160,6 +161,10 @@ help # This message return ForcePrintStatusAndHintRule, nil case "chunk-size": { + if argIsQuestion { + fmt.Fprintf(writer, "%+v\n", atomic.LoadInt64(&this.migrationContext.ChunkSize)) + return NoPrintStatusRule, nil + } if chunkSize, err := strconv.Atoi(arg); err != nil { return NoPrintStatusRule, err } else { @@ -169,6 +174,10 @@ help # This message } case "max-lag-millis": { + if argIsQuestion { + fmt.Fprintf(writer, "%+v\n", atomic.LoadInt64(&this.migrationContext.MaxLagMillisecondsThrottleThreshold)) + return NoPrintStatusRule, nil + } if maxLagMillis, err := strconv.Atoi(arg); err != nil { return NoPrintStatusRule, err } else { @@ -182,6 +191,10 @@ help # This message } case "nice-ratio": { + if argIsQuestion { + fmt.Fprintf(writer, "%+v\n", this.migrationContext.GetNiceRatio()) + return NoPrintStatusRule, nil + } if niceRatio, err := strconv.ParseFloat(arg, 64); err != nil { return NoPrintStatusRule, err } else { @@ -191,6 +204,11 @@ help # This message } case "max-load": { + if argIsQuestion { + maxLoad := this.migrationContext.GetMaxLoad() + fmt.Fprintf(writer, "%s\n", maxLoad.String()) + return NoPrintStatusRule, nil + } if err := this.migrationContext.ReadMaxLoad(arg); err != nil { return NoPrintStatusRule, err } @@ -198,6 +216,11 @@ help # This message } case "critical-load": { + if argIsQuestion { + criticalLoad := this.migrationContext.GetCriticalLoad() + fmt.Fprintf(writer, "%s\n", criticalLoad.String()) + return NoPrintStatusRule, nil + } if err := this.migrationContext.ReadCriticalLoad(arg); err != nil { return NoPrintStatusRule, err } @@ -205,12 +228,20 @@ help # This message } case "throttle-query": { + if argIsQuestion { + fmt.Fprintf(writer, "%+v\n", this.migrationContext.GetThrottleQuery()) + return NoPrintStatusRule, nil + } this.migrationContext.SetThrottleQuery(arg) fmt.Fprintf(writer, throttleHint) return ForcePrintStatusAndHintRule, nil } case "throttle-control-replicas": { + if argIsQuestion { + fmt.Fprintf(writer, "%s\n", this.migrationContext.GetThrottleControlReplicaKeys().ToCommaDelimitedList()) + return NoPrintStatusRule, nil + } if err := this.migrationContext.ReadThrottleControlReplicaKeys(arg); err != nil { return NoPrintStatusRule, err } From 599f4323a2aceafadacb14f7b56c0324199c15f1 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 29 Jan 2017 09:25:41 +0200 Subject: [PATCH 077/104] documenting the question argument --- doc/interactive-commands.md | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/doc/interactive-commands.md b/doc/interactive-commands.md index 84c887d..c6398c5 100644 --- a/doc/interactive-commands.md +++ b/doc/interactive-commands.md @@ -26,9 +26,9 @@ Both interfaces may serve at the same time. Both respond to simple text command, - The `critical-load` format must be: `some_status=[,some_status=...]`' - For example: `Threads_running=1000,threads_connected=5000`, and you would then write/echo `critical-load=Threads_running=1000,threads_connected=5000` to the socket. - `nice-ratio=`: change _nice_ ratio: 0 for aggressive (not nice, not sleeping), positive integer `n`: - - For any `1ms` spent copying rows, spend `n*1ms` units of time sleeping. - - Examples: assume a single rows chunk copy takes `100ms` to complete. - - `nice-ratio=0.5` will cause `gh-ost` to sleep for `50ms` immediately following. + - For any `1ms` spent copying rows, spend `n*1ms` units of time sleeping. + - Examples: assume a single rows chunk copy takes `100ms` to complete. + - `nice-ratio=0.5` will cause `gh-ost` to sleep for `50ms` immediately following. - `nice-ratio=1` will cause `gh-ost` to sleep for `100ms`, effectively doubling runtime - value of `2` will effectively triple the runtime; etc. - `throttle-query`: change throttle query @@ -38,6 +38,10 @@ Both interfaces may serve at the same time. Both respond to simple text command, - `unpostpone`: at a time where `gh-ost` is postponing the [cut-over](cut-over.md) phase, instruct `gh-ost` to stop postponing and proceed immediately to cut-over. - `panic`: immediately panic and abort operation +### Querying for data + +For commands that accept an argumetn as value, pass `?` (question mark) to _get_ current value rather than _set_ a new one. + ### Examples While migration is running: @@ -63,6 +67,11 @@ $ echo "chunk-size=250" | nc -U /tmp/gh-ost.test.sample_data_0.sock # Serving on TCP port: 10001 ``` +```shell +$ echo "chunk-size=?" | nc -U /tmp/gh-ost.test.sample_data_0.sock +250 +``` + ```shell $ echo throttle | nc -U /tmp/gh-ost.test.sample_data_0.sock From 8219209c8d30540a32a007c11200a91ef6ce4319 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 29 Jan 2017 09:25:58 +0200 Subject: [PATCH 078/104] tests accept postpone flag file --- localtests/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/localtests/test.sh b/localtests/test.sh index c0b97a5..8bbcf6f 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -88,7 +88,7 @@ test_single() { --throttle-query='select timestampdiff(second, min(last_update), now()) < 5 from _gh_ost_test_ghc' \ --serve-socket-file=/tmp/gh-ost.test.sock \ --initially-drop-socket-file \ - --postpone-cut-over-flag-file="" \ + --postpone-cut-over-flag-file=/tmp/gh-ost.test.postpone.flag \ --test-on-replica \ --default-retries=1 \ --verbose \ From be1ab175c77df6cea97052130d699a5b96ee60cd Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 29 Jan 2017 09:56:25 +0200 Subject: [PATCH 079/104] status presents with '# throttle-control-replicas count:' --- go/logic/migrator.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 61db6db..05ab99a 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -761,6 +761,12 @@ func (this *Migrator) printMigrationStatusHint(writers ...io.Writer) { throttleQuery, )) } + if throttleControlReplicaKeys := this.migrationContext.GetThrottleControlReplicaKeys(); throttleControlReplicaKeys.Len() > 0 { + fmt.Fprintln(w, fmt.Sprintf("# throttle-control-replicas count: %+v", + throttleControlReplicaKeys.Len(), + )) + } + if this.migrationContext.PostponeCutOverFlagFile != "" { setIndicator := "" if base.FileExists(this.migrationContext.PostponeCutOverFlagFile) { From baee4f69f90f2c55e68a45166934ca9e768ce573 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 29 Jan 2017 10:18:39 +0200 Subject: [PATCH 080/104] fixing phantom throttle-control-replicas lag result --- go/base/context.go | 6 +++++- go/logic/throttler.go | 4 +--- go/mysql/utils.go | 8 ++++++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 6142f9c..7b196af 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -559,7 +559,11 @@ func (this *MigrationContext) GetControlReplicasLagResult() mysql.ReplicationLag func (this *MigrationContext) SetControlReplicasLagResult(lagResult *mysql.ReplicationLagResult) { this.throttleMutex.Lock() defer this.throttleMutex.Unlock() - this.controlReplicasLagResult = *lagResult + if lagResult == nil { + this.controlReplicasLagResult = *mysql.NewNoReplicationLagResult() + } else { + this.controlReplicasLagResult = *lagResult + } } func (this *MigrationContext) GetThrottleControlReplicaKeys() *mysql.InstanceKeyMap { diff --git a/go/logic/throttler.go b/go/logic/throttler.go index 1635b5e..beb4f47 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -158,9 +158,7 @@ func (this *Throttler) collectControlReplicasLag() { // No need to read lag return } - if result := readControlReplicasLag(); result != nil { - this.migrationContext.SetControlReplicasLagResult(result) - } + this.migrationContext.SetControlReplicasLagResult(readControlReplicasLag()) } aggressiveTicker := time.Tick(100 * time.Millisecond) relaxedFactor := 10 diff --git a/go/mysql/utils.go b/go/mysql/utils.go index bfd3c24..fb70dc9 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -22,6 +22,14 @@ type ReplicationLagResult struct { Err error } +func NewNoReplicationLagResult() *ReplicationLagResult { + return &ReplicationLagResult{Lag: 0, Err: nil} +} + +func (this *ReplicationLagResult) HasLag() bool { + return this.Lag > 0 +} + // GetReplicationLag returns replication lag for a given connection config; either by explicit query // or via SHOW SLAVE STATUS func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time.Duration, err error) { From 9d47992d7ecea65a649bfc56ba2a553b3c368440 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 2 Feb 2017 11:18:07 +0200 Subject: [PATCH 081/104] testing/running on replica: gets lags via SHOW SLAVE STATUS --- go/logic/throttler.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/go/logic/throttler.go b/go/logic/throttler.go index 1635b5e..41ece40 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -84,18 +84,30 @@ func (this *Throttler) parseChangelogHeartbeat(heartbeatValue string) (err error } } -// collectHeartbeat reads the latest changelog heartbeat value -func (this *Throttler) collectHeartbeat() { +// collectReplicationLag reads the latest changelog heartbeat value +func (this *Throttler) collectReplicationLag() { ticker := time.Tick(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond) for range ticker { go func() error { if atomic.LoadInt64(&this.migrationContext.CleanupImminentFlag) > 0 { return nil } - if heartbeatValue, err := this.inspector.readChangelogState("heartbeat"); err != nil { - return log.Errore(err) + + if this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica { + // when running on replica, the heartbeat injection is also done on the replica. + // This means we will always get a good heartbeat value. + // When runnign on replica, we should instead check the `SHOW SLAVE STATUS` output. + if lag, err := mysql.GetReplicationLag(this.inspector.connectionConfig); err != nil { + return log.Errore(err) + } else { + atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) + } } else { - this.parseChangelogHeartbeat(heartbeatValue) + if heartbeatValue, err := this.inspector.readChangelogState("heartbeat"); err != nil { + return log.Errore(err) + } else { + this.parseChangelogHeartbeat(heartbeatValue) + } } return nil }() @@ -114,6 +126,7 @@ func (this *Throttler) collectControlReplicasLag() { readReplicaLag := func(connectionConfig *mysql.ConnectionConfig) (lag time.Duration, err error) { dbUri := connectionConfig.GetDBUri("information_schema") + var heartbeatValue string if db, _, err := sqlutils.GetDB(dbUri); err != nil { return lag, err @@ -272,7 +285,7 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { // that may affect throttling. There are several components, all running independently, // that collect such metrics. func (this *Throttler) initiateThrottlerCollection(firstThrottlingCollected chan<- bool) { - go this.collectHeartbeat() + go this.collectReplicationLag() go this.collectControlReplicasLag() go func() { From d6347c4f33500c35c29ec83fc2edfcf45845f9ef Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 5 Feb 2017 08:03:28 +0200 Subject: [PATCH 082/104] adding common-questions doc page --- doc/questions.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 doc/questions.md diff --git a/doc/questions.md b/doc/questions.md new file mode 100644 index 0000000..33757e5 --- /dev/null +++ b/doc/questions.md @@ -0,0 +1,22 @@ +# How? + +### How does the cut-over work? Is it really atomic? + +The cut-over phase, where the original table is swapped away, and the _ghost_ table takes its place, is an atomic, blocking, controlled operation. + +- Atomic: the tables are swapped together. There is no gap where your table does not exist. +- Blocking: all app queries involving the migrated (original) table are either operate on the original table, or are blocked, or proceed to operate on the _new_ table (formerly the _ghost_ table, now swapped in). +- Controlled: the cut-over times out at pre-defined threshold, and is atomically aborted, then re-attempted. Cut-over only takes place when no lags are present, and otherwise no throttling reason is found. Cut-over step itself gets high priority and is never throttled. + +Read more on [cut-over](cut-over.md) and on the [cut-over design Issue](https://github.com/github/gh-ost/issues/82) + + +# Is it possible to? + +### Is it possible to add a UNIQUE KEY? + +Adding a `UNIQUE KEY` is possible, in the condition that no violation will occur. That is, you must make sure there aren't any violating rows on your table before, and during the migration. + +At this time there is no equivalent to `ALETER IGNORE`, where duplicates are implicitly and silently thrown away. This behavior may be supported in the future. + +# Why From ad7b700bd8d82aa74d22080531ce3733e4815ff9 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 5 Feb 2017 08:05:10 +0200 Subject: [PATCH 083/104] link to questions page --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 451b5b2..f312d2f 100644 --- a/README.md +++ b/README.md @@ -58,9 +58,10 @@ More tips: Also see: - [requirements and limitations](doc/requirements-and-limitations.md) +- [common questions](doc/questions.md) - [what if?](doc/what-if.md) - [the fine print](doc/the-fine-print.md) -- [Questions](https://github.com/github/gh-ost/issues?q=label%3Aquestion) +- [Community questions](https://github.com/github/gh-ost/issues?q=label%3Aquestion) ## What's in a name? From 6a468c4f365300637be3002fb9a21be7a83581f9 Mon Sep 17 00:00:00 2001 From: Gillian Gunson Date: Sun, 5 Feb 2017 17:28:22 -0800 Subject: [PATCH 084/104] Fixed tiny typo --- doc/questions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/questions.md b/doc/questions.md index 33757e5..7f9cb20 100644 --- a/doc/questions.md +++ b/doc/questions.md @@ -17,6 +17,6 @@ Read more on [cut-over](cut-over.md) and on the [cut-over design Issue](https:// Adding a `UNIQUE KEY` is possible, in the condition that no violation will occur. That is, you must make sure there aren't any violating rows on your table before, and during the migration. -At this time there is no equivalent to `ALETER IGNORE`, where duplicates are implicitly and silently thrown away. This behavior may be supported in the future. +At this time there is no equivalent to `ALTER IGNORE`, where duplicates are implicitly and silently thrown away. This behavior may be supported in the future. # Why From ebf7844e373d1ca136b32dd32dc7f3ea85bdfaad Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 6 Feb 2017 07:48:48 +0200 Subject: [PATCH 085/104] 5.7 deprecation of alter ignore --- doc/questions.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/questions.md b/doc/questions.md index 33757e5..f4788f1 100644 --- a/doc/questions.md +++ b/doc/questions.md @@ -17,6 +17,10 @@ Read more on [cut-over](cut-over.md) and on the [cut-over design Issue](https:// Adding a `UNIQUE KEY` is possible, in the condition that no violation will occur. That is, you must make sure there aren't any violating rows on your table before, and during the migration. -At this time there is no equivalent to `ALETER IGNORE`, where duplicates are implicitly and silently thrown away. This behavior may be supported in the future. +At this time there is no equivalent to `ALETER IGNORE`, where duplicates are implicitly and silently thrown away. The MySQL `5.7` docs say: + +> As of MySQL 5.7.4, the IGNORE clause for ALTER TABLE is removed and its use produces an error. + +It is therefore unlikely that `gh-ost` will support this behavior. # Why From 6fa32f6d54144803e54cee52e10f97478ff769a4 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 7 Feb 2017 09:31:52 +0200 Subject: [PATCH 086/104] cut-over failure on test-on-replica starts replication again --- go/logic/applier.go | 27 ++++++++++++++++++++++++-- go/logic/hooks.go | 5 +++++ go/logic/migrator.go | 46 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 69 insertions(+), 9 deletions(-) diff --git a/go/logic/applier.go b/go/logic/applier.go index d1ad331..45e92f3 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -588,11 +588,22 @@ func (this *Applier) RenameTablesRollback() (renameError error) { // and have them written to the binary log, so that we can then read them via streamer. func (this *Applier) StopSlaveIOThread() error { query := `stop /* gh-ost */ slave io_thread` - log.Infof("Stopping replication") + log.Infof("Stopping replication IO thread") if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Replication stopped") + log.Infof("Replication IO thread stopped") + return nil +} + +// StartSlaveIOThread is applicable with --test-on-replica +func (this *Applier) StartSlaveIOThread() error { + query := `start /* gh-ost */ slave io_thread` + log.Infof("Starting replication IO thread") + if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { + return err + } + log.Infof("Replication IO thread started") return nil } @@ -635,6 +646,18 @@ func (this *Applier) StopReplication() error { return nil } +// StartReplication is used by `--test-on-replica` on cut-over failure +func (this *Applier) StartReplication() error { + if err := this.StartSlaveIOThread(); err != nil { + return err + } + if err := this.StartSlaveSQLThread(); err != nil { + return err + } + log.Infof("Replication started") + return nil +} + // GetSessionLockName returns a name for the special hint session voluntary lock func (this *Applier) GetSessionLockName(sessionId int64) string { return fmt.Sprintf("gh-ost.%d.lock", sessionId) diff --git a/go/logic/hooks.go b/go/logic/hooks.go index 25d745a..89dbd65 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -30,6 +30,7 @@ const ( onFailure = "gh-ost-on-failure" onStatus = "gh-ost-on-status" onStopReplication = "gh-ost-on-stop-replication" + onStartReplication = "gh-ost-on-start-replication" ) type HooksExecutor struct { @@ -152,3 +153,7 @@ func (this *HooksExecutor) onStatus(statusMessage string) error { func (this *HooksExecutor) onStopReplication() error { return this.executeHooks(onStopReplication) } + +func (this *HooksExecutor) onStartReplication() error { + return this.executeHooks(onStartReplication) +} diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 61db6db..58b5e79 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -385,8 +385,38 @@ func (this *Migrator) ExecOnFailureHook() (err error) { return this.hooksExecutor.onFailure() } +func (this *Migrator) handleCutOverResult(cutOverError error) (err error) { + if this.migrationContext.TestOnReplica { + // We're merly testing, we don't want to keep this state. Rollback the renames as possible + this.applier.RenameTablesRollback() + } + if cutOverError == nil { + return nil + } + // Only on error: + + if this.migrationContext.TestOnReplica { + // With `--test-on-replica` we stop replication thread, and then proceed to use + // the same cut-over phase as the master would use. That means we take locks + // and swap the tables. + // The difference is that we will later swap the tables back. + if err := this.hooksExecutor.onStartReplication(); err != nil { + return log.Errore(err) + } + if this.migrationContext.TestOnReplicaSkipReplicaStop { + log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not starting replication.") + } else { + log.Debugf("testing on replica. Starting replication IO thread after cut-over failure") + if err := this.retryOperation(this.applier.StartReplication); err != nil { + return log.Errore(err) + } + } + } + return nil +} + // cutOver performs the final step of migration, based on migration -// type (on replica? bumpy? safe?) +// type (on replica? atomic? safe?) func (this *Migrator) cutOver() (err error) { if this.migrationContext.Noop { log.Debugf("Noop operation; not really swapping tables") @@ -441,18 +471,18 @@ func (this *Migrator) cutOver() (err error) { return err } } - // We're merly testing, we don't want to keep this state. Rollback the renames as possible - defer this.applier.RenameTablesRollback() - // We further proceed to do the cutover by normal means; the 'defer' above will rollback the swap } if this.migrationContext.CutOverType == base.CutOverAtomic { // Atomic solution: we use low timeout and multiple attempts. But for // each failed attempt, we throttle until replication lag is back to normal err := this.atomicCutOver() + this.handleCutOverResult(err) return err } if this.migrationContext.CutOverType == base.CutOverTwoStep { - return this.cutOverTwoStep() + err := this.cutOverTwoStep() + this.handleCutOverResult(err) + return err } return log.Fatalf("Unknown cut-over type: %d; should never get here!", this.migrationContext.CutOverType) } @@ -1043,8 +1073,10 @@ func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { availableEvents := len(this.applyEventsQueue) batchSize := int(atomic.LoadInt64(&this.migrationContext.DMLBatchSize)) - if availableEvents > batchSize { - availableEvents = batchSize + if availableEvents > batchSize-1 { + // The "- 1" is because we already consumed one event: the original event that led to this function getting called. + // So, if DMLBatchSize==1 we wish to not process any further events + availableEvents = batchSize - 1 } for i := 0; i < availableEvents; i++ { additionalStruct := <-this.applyEventsQueue From ed5dd7f0499bbb8b303a8a124b139b023c3583eb Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 7 Feb 2017 09:41:33 +0200 Subject: [PATCH 087/104] Collecting and presenting MySQL versions of applier and inspector --- go/base/context.go | 2 ++ go/logic/applier.go | 5 +++-- go/logic/inspect.go | 5 +++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 6142f9c..55c858a 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -135,7 +135,9 @@ type MigrationContext struct { OriginalBinlogFormat string OriginalBinlogRowImage string InspectorConnectionConfig *mysql.ConnectionConfig + InspectorMySQLVersion string ApplierConnectionConfig *mysql.ConnectionConfig + ApplierMySQLVersion string StartTime time.Time RowCopyStartTime time.Time RowCopyEndTime time.Time diff --git a/go/logic/applier.go b/go/logic/applier.go index d1ad331..d1d7970 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -70,14 +70,15 @@ func (this *Applier) InitDBConnections() (err error) { if err := this.readTableColumns(); err != nil { return err } + log.Infof("Applier initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.ApplierMySQLVersion) return nil } // validateConnection issues a simple can-connect to MySQL func (this *Applier) validateConnection(db *gosql.DB) error { - query := `select @@global.port` + query := `select @@global.port, @@global.version` var port int - if err := db.QueryRow(query).Scan(&port); err != nil { + if err := db.QueryRow(query).Scan(&port, &this.migrationContext.ApplierMySQLVersion); err != nil { return err } if port != this.connectionConfig.Key.Port { diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 9daef39..862be16 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -60,6 +60,7 @@ func (this *Inspector) InitDBConnections() (err error) { if err := this.applyBinlogFormat(); err != nil { return err } + log.Infof("Inspector initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.InspectorMySQLVersion) return nil } @@ -168,9 +169,9 @@ func (this *Inspector) inspectOriginalAndGhostTables() (err error) { // validateConnection issues a simple can-connect to MySQL func (this *Inspector) validateConnection() error { - query := `select @@global.port` + query := `select @@global.port, @@global.version` var port int - if err := this.db.QueryRow(query).Scan(&port); err != nil { + if err := this.db.QueryRow(query).Scan(&port, &this.migrationContext.InspectorMySQLVersion); err != nil { return err } if port != this.connectionConfig.Key.Port { From 10edf3c063691c320df4bb2b0d7cddd2a7480576 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 7 Feb 2017 12:13:19 +0200 Subject: [PATCH 088/104] Migration only starting after first replication lag metric collected --- go/logic/migrator.go | 4 ++- go/logic/throttler.go | 58 ++++++++++++++++++++++++------------------- go/mysql/utils.go | 4 ++- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 61db6db..5880c99 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -940,7 +940,9 @@ func (this *Migrator) initiateThrottler() error { go this.throttler.initiateThrottlerCollection(this.firstThrottlingCollected) log.Infof("Waiting for first throttle metrics to be collected") - <-this.firstThrottlingCollected + <-this.firstThrottlingCollected // replication lag + <-this.firstThrottlingCollected // other metrics + log.Infof("First throttle metrics collected") go this.throttler.initiateThrottlerChecks() return nil diff --git a/go/logic/throttler.go b/go/logic/throttler.go index 41ece40..2430c13 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -85,32 +85,37 @@ func (this *Throttler) parseChangelogHeartbeat(heartbeatValue string) (err error } // collectReplicationLag reads the latest changelog heartbeat value -func (this *Throttler) collectReplicationLag() { +func (this *Throttler) collectReplicationLag(firstThrottlingCollected chan<- bool) { + collectFunc := func() error { + if atomic.LoadInt64(&this.migrationContext.CleanupImminentFlag) > 0 { + return nil + } + + if this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica { + // when running on replica, the heartbeat injection is also done on the replica. + // This means we will always get a good heartbeat value. + // When runnign on replica, we should instead check the `SHOW SLAVE STATUS` output. + if lag, err := mysql.GetReplicationLag(this.inspector.connectionConfig); err != nil { + return log.Errore(err) + } else { + atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) + } + } else { + if heartbeatValue, err := this.inspector.readChangelogState("heartbeat"); err != nil { + return log.Errore(err) + } else { + this.parseChangelogHeartbeat(heartbeatValue) + } + } + return nil + } + + collectFunc() + firstThrottlingCollected <- true + ticker := time.Tick(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond) for range ticker { - go func() error { - if atomic.LoadInt64(&this.migrationContext.CleanupImminentFlag) > 0 { - return nil - } - - if this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica { - // when running on replica, the heartbeat injection is also done on the replica. - // This means we will always get a good heartbeat value. - // When runnign on replica, we should instead check the `SHOW SLAVE STATUS` output. - if lag, err := mysql.GetReplicationLag(this.inspector.connectionConfig); err != nil { - return log.Errore(err) - } else { - atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) - } - } else { - if heartbeatValue, err := this.inspector.readChangelogState("heartbeat"); err != nil { - return log.Errore(err) - } else { - this.parseChangelogHeartbeat(heartbeatValue) - } - } - return nil - }() + go collectFunc() } } @@ -285,13 +290,14 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { // that may affect throttling. There are several components, all running independently, // that collect such metrics. func (this *Throttler) initiateThrottlerCollection(firstThrottlingCollected chan<- bool) { - go this.collectReplicationLag() + go this.collectReplicationLag(firstThrottlingCollected) go this.collectControlReplicasLag() go func() { - throttlerMetricsTick := time.Tick(1 * time.Second) this.collectGeneralThrottleMetrics() firstThrottlingCollected <- true + + throttlerMetricsTick := time.Tick(1 * time.Second) for range throttlerMetricsTick { this.collectGeneralThrottleMetrics() } diff --git a/go/mysql/utils.go b/go/mysql/utils.go index bfd3c24..80bcf1b 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -32,9 +32,11 @@ func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time. } err = sqlutils.QueryRowsMap(db, `show slave status`, func(m sqlutils.RowMap) error { + slaveIORunning := m.GetString("Slave_IO_Running") + slaveSQLRunning := m.GetString("Slave_SQL_Running") secondsBehindMaster := m.GetNullInt64("Seconds_Behind_Master") if !secondsBehindMaster.Valid { - return fmt.Errorf("replication not running") + return fmt.Errorf("replication not running; Slave_IO_Running=%+v, Slave_SQL_Running=", slaveIORunning, slaveSQLRunning) } replicationLag = time.Duration(secondsBehindMaster.Int64) * time.Second return nil From f710ca6772d98eeb897aaa339f61d7c55e42e427 Mon Sep 17 00:00:00 2001 From: Ivan Groenewold Date: Tue, 7 Feb 2017 10:56:28 -0300 Subject: [PATCH 089/104] additional info re: Tungsten --- doc/cheatsheet.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/cheatsheet.md b/doc/cheatsheet.md index 223f47c..9db8467 100644 --- a/doc/cheatsheet.md +++ b/doc/cheatsheet.md @@ -146,8 +146,12 @@ gh-ost --allow-master-master --assume-master-host=a.specific.master.com Topologies using _tungsten replicator_ are peculiar in that the participating servers are not actually aware they are replicating. The _tungsten replicator_ looks just like another app issuing queries on those hosts. `gh-ost` is unable to identify that a server participates in a _tungsten_ topology. -If you choose to migrate directly on master (see above), there's nothing special you need to do. If you choose to migrate via replica, then you must supply the identity of the master, and indicate this is a tungsten setup, as follows: +If you choose to migrate directly on master (see above), there's nothing special you need to do. + +If you choose to migrate via replica, then you need to make sure Tungsten is configured with log-slave-updates parameter (note this is different from MySQL's own log-slave-updates parameter), otherwise changes will not be in the replica's binlog, causing data to be corrupted after table swap. You must also supply the identity of the master, and indicate this is a tungsten setup, as follows: ``` gh-ost --tungsten --assume-master-host=the.topology.master.com ``` + +Also note that `--switch-to-rbr` does not work for a Tungsten setup as the replication process is external, so you need to make sure `binlog_format` is set to ROW beforehand. From fbf88fd57127d05a8d1f449c6918b2d3a1e23f72 Mon Sep 17 00:00:00 2001 From: Ivan Groenewold Date: Tue, 7 Feb 2017 11:45:20 -0300 Subject: [PATCH 090/104] additional note for Tungsten and rbr --- doc/cheatsheet.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/cheatsheet.md b/doc/cheatsheet.md index 9db8467..6ff9682 100644 --- a/doc/cheatsheet.md +++ b/doc/cheatsheet.md @@ -154,4 +154,4 @@ If you choose to migrate via replica, then you need to make sure Tungsten is con gh-ost --tungsten --assume-master-host=the.topology.master.com ``` -Also note that `--switch-to-rbr` does not work for a Tungsten setup as the replication process is external, so you need to make sure `binlog_format` is set to ROW beforehand. +Also note that `--switch-to-rbr` does not work for a Tungsten setup as the replication process is external, so you need to make sure `binlog_format` is set to ROW before Tungsten Replicator connects to the server and starts applying events from the master. From 3511aa35dfc0c3f14bcda151620ea14ee13a5b87 Mon Sep 17 00:00:00 2001 From: Ivan Groenewold Date: Tue, 7 Feb 2017 15:12:34 -0300 Subject: [PATCH 091/104] improve output without --verbose option --- go/cmd/gh-ost/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index f9ce1d7..e06014c 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -242,5 +242,5 @@ func main() { migrator.ExecOnFailureHook() log.Fatale(err) } - log.Info("Done") + fmt.Fprintf(os.Stdout, "# Done\n") } From 57409c21986dce4ed6c026ddec60470d84ecc247 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Wed, 8 Feb 2017 12:24:44 +0200 Subject: [PATCH 092/104] added mising slaveSQLRunning value --- go/mysql/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/mysql/utils.go b/go/mysql/utils.go index 80bcf1b..cb3ae36 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -36,7 +36,7 @@ func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time. slaveSQLRunning := m.GetString("Slave_SQL_Running") secondsBehindMaster := m.GetNullInt64("Seconds_Behind_Master") if !secondsBehindMaster.Valid { - return fmt.Errorf("replication not running; Slave_IO_Running=%+v, Slave_SQL_Running=", slaveIORunning, slaveSQLRunning) + return fmt.Errorf("replication not running; Slave_IO_Running=%+v, Slave_SQL_Running=%+v", slaveIORunning, slaveSQLRunning) } replicationLag = time.Duration(secondsBehindMaster.Int64) * time.Second return nil From 1acbca105d449a3b2a48d9a43ea127b28fcd2256 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 9 Feb 2017 09:09:35 +0200 Subject: [PATCH 093/104] applied gofmt --- go/cmd/gh-ost/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index e06014c..8e5e20c 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -242,5 +242,5 @@ func main() { migrator.ExecOnFailureHook() log.Fatale(err) } - fmt.Fprintf(os.Stdout, "# Done\n") + fmt.Fprintf(os.Stdout, "# Done\n") } From f6af07ff7a12ae2df23c270cd8390eb1cc7327fc Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 12 Feb 2017 13:13:54 +0200 Subject: [PATCH 094/104] updated go-mysql library 2017-02 --- .../github.com/siddontang/go-mysql/.gitignore | 3 + .../siddontang/go-mysql/.travis.yml | 32 + vendor/github.com/siddontang/go-mysql/LICENSE | 20 + .../github.com/siddontang/go-mysql/Makefile | 33 + .../github.com/siddontang/go-mysql/README.md | 229 +++ .../vendor/github.com/BurntSushi/toml/COPYING | 14 + .../github.com/BurntSushi/toml/decode.go | 492 +++++++ .../github.com/BurntSushi/toml/decode_meta.go | 122 ++ .../vendor/github.com/BurntSushi/toml/doc.go | 27 + .../github.com/BurntSushi/toml/encode.go | 551 ++++++++ .../BurntSushi/toml/encoding_types.go | 19 + .../BurntSushi/toml/encoding_types_1.1.go | 18 + .../vendor/github.com/BurntSushi/toml/lex.go | 874 ++++++++++++ .../github.com/BurntSushi/toml/parse.go | 498 +++++++ .../github.com/BurntSushi/toml/type_check.go | 91 ++ .../github.com/BurntSushi/toml/type_fields.go | 241 ++++ .../github.com/go-sql-driver/mysql/LICENSE | 373 +++++ .../go-sql-driver/mysql/appengine.go | 19 + .../github.com/go-sql-driver/mysql/buffer.go | 147 ++ .../go-sql-driver/mysql/collations.go | 250 ++++ .../go-sql-driver/mysql/connection.go | 372 +++++ .../github.com/go-sql-driver/mysql/const.go | 163 +++ .../github.com/go-sql-driver/mysql/driver.go | 167 +++ .../github.com/go-sql-driver/mysql/dsn.go | 513 +++++++ .../github.com/go-sql-driver/mysql/errors.go | 131 ++ .../github.com/go-sql-driver/mysql/infile.go | 181 +++ .../github.com/go-sql-driver/mysql/packets.go | 1243 +++++++++++++++++ .../github.com/go-sql-driver/mysql/result.go | 22 + .../github.com/go-sql-driver/mysql/rows.go | 112 ++ .../go-sql-driver/mysql/statement.go | 150 ++ .../go-sql-driver/mysql/transaction.go | 31 + .../github.com/go-sql-driver/mysql/utils.go | 740 ++++++++++ .../vendor/github.com/jmoiron/sqlx/LICENSE | 23 + .../vendor/github.com/jmoiron/sqlx/bind.go | 186 +++ .../vendor/github.com/jmoiron/sqlx/doc.go | 12 + .../vendor/github.com/jmoiron/sqlx/named.go | 336 +++++ .../jmoiron/sqlx/reflectx/reflect.go | 371 +++++ .../vendor/github.com/jmoiron/sqlx/sqlx.go | 992 +++++++++++++ .../vendor/github.com/juju/errors/LICENSE | 191 +++ .../vendor/github.com/juju/errors/doc.go | 81 ++ .../vendor/github.com/juju/errors/error.go | 145 ++ .../github.com/juju/errors/errortypes.go | 284 ++++ .../github.com/juju/errors/functions.go | 330 +++++ .../vendor/github.com/juju/errors/path.go | 38 + .../vendor/github.com/ngaut/log/LICENSE | 165 +++ .../vendor/github.com/ngaut/log/crash_unix.go | 18 + .../vendor/github.com/ngaut/log/crash_win.go | 37 + .../vendor/github.com/ngaut/log/log.go | 380 +++++ .../github.com/pingcap/check/benchmark.go | 187 +++ .../vendor/github.com/pingcap/check/check.go | 954 +++++++++++++ .../github.com/pingcap/check/checkers.go | 458 ++++++ .../github.com/pingcap/check/checkers2.go | 131 ++ .../github.com/pingcap/check/compare.go | 161 +++ .../github.com/pingcap/check/helpers.go | 231 +++ .../github.com/pingcap/check/printer.go | 168 +++ .../vendor/github.com/pingcap/check/run.go | 175 +++ .../vendor/github.com/satori/go.uuid/LICENSE | 20 + .../vendor/github.com/satori/go.uuid/uuid.go | 488 +++++++ .../vendor/github.com/siddontang/go/LICENSE | 20 + .../github.com/siddontang/go/hack/hack.go | 27 + .../siddontang/go/ioutil2/ioutil.go | 39 + .../siddontang/go/ioutil2/sectionwriter.go | 69 + .../github.com/siddontang/go/sync2/atomic.go | 146 ++ .../siddontang/go/sync2/semaphore.go | 65 + .../_vendor/vendor/golang.org/x/net/LICENSE | 27 + .../_vendor/vendor/golang.org/x/net/PATENTS | 22 + .../golang.org/x/net/context/context.go | 447 ++++++ .../siddontang/go-mysql/canal/canal.go | 307 ++++ .../siddontang/go-mysql/canal/canal_test.go | 94 ++ .../siddontang/go-mysql/canal/config.go | 79 ++ .../siddontang/go-mysql/canal/dump.go | 120 ++ .../siddontang/go-mysql/canal/handler.go | 41 + .../siddontang/go-mysql/canal/master.go | 89 ++ .../siddontang/go-mysql/canal/rows.go | 59 + .../siddontang/go-mysql/canal/sync.go | 137 ++ .../siddontang/go-mysql/client/auth.go | 39 +- .../siddontang/go-mysql/client/client_test.go | 27 +- .../siddontang/go-mysql/client/conn.go | 16 +- .../go-mysql/cmd/go-binlogparser/main.go | 28 + .../siddontang/go-mysql/cmd/go-canal/main.go | 101 ++ .../go-mysql/cmd/go-mysqlbinlog/main.go | 76 + .../go-mysql/cmd/go-mysqldump/main.go | 66 + .../siddontang/go-mysql/docker/Makefile | 53 + .../siddontang/go-mysql/driver/dirver_test.go | 73 + .../siddontang/go-mysql/driver/driver.go | 230 +++ .../siddontang/go-mysql/dump/dump.go | 161 +++ .../siddontang/go-mysql/dump/dump_test.go | 198 +++ .../siddontang/go-mysql/dump/parser.go | 188 +++ .../siddontang/go-mysql/failover/const.go | 11 + .../siddontang/go-mysql/failover/doc.go | 8 + .../siddontang/go-mysql/failover/failover.go | 68 + .../go-mysql/failover/failover_test.go | 176 +++ .../siddontang/go-mysql/failover/handler.go | 22 + .../go-mysql/failover/mariadb_gtid_handler.go | 142 ++ .../go-mysql/failover/mysql_gtid_handler.go | 141 ++ .../siddontang/go-mysql/failover/server.go | 174 +++ .../github.com/siddontang/go-mysql/glide.lock | 30 + .../github.com/siddontang/go-mysql/glide.yaml | 26 + .../siddontang/go-mysql/mysql/const.go | 3 +- .../siddontang/go-mysql/mysql/mysql_test.go | 18 +- .../siddontang/go-mysql/mysql/resultset.go | 58 +- .../siddontang/go-mysql/mysql/util.go | 16 + .../siddontang/go-mysql/replication/backup.go | 19 +- .../go-mysql/replication/binlogstreamer.go | 35 +- .../go-mysql/replication/binlogsyncer.go | 415 +++--- .../siddontang/go-mysql/replication/bitmap.go | 38 - .../go-mysql/replication/json_binary.go | 479 +++++++ .../siddontang/go-mysql/replication/parser.go | 22 +- .../go-mysql/replication/parser_test.go | 2 +- .../go-mysql/replication/replication_test.go | 185 ++- .../go-mysql/replication/row_event.go | 129 +- .../go-mysql/replication/row_event_test.go | 105 +- .../siddontang/go-mysql/schema/schema.go | 215 +++ .../siddontang/go-mysql/schema/schema_test.go | 84 ++ .../siddontang/go-mysql/server/auth.go | 119 ++ .../siddontang/go-mysql/server/command.go | 151 ++ .../siddontang/go-mysql/server/conn.go | 113 ++ .../siddontang/go-mysql/server/resp.go | 142 ++ .../siddontang/go-mysql/server/server_test.go | 230 +++ .../siddontang/go-mysql/server/stmt.go | 363 +++++ .../siddontang/go-mysql/vitess_license | 28 + 121 files changed, 20868 insertions(+), 383 deletions(-) create mode 100644 vendor/github.com/siddontang/go-mysql/.gitignore create mode 100644 vendor/github.com/siddontang/go-mysql/.travis.yml create mode 100644 vendor/github.com/siddontang/go-mysql/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/Makefile create mode 100644 vendor/github.com/siddontang/go-mysql/README.md create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/COPYING create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode_meta.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/doc.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encode.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/lex.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/parse.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_check.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_fields.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/appengine.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/buffer.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/collations.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/connection.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/const.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/driver.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/dsn.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/errors.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/infile.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/packets.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/result.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/rows.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/statement.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/transaction.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/utils.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/bind.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/doc.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/named.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/sqlx.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/doc.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/error.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/errortypes.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/functions.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/path.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_unix.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_win.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/log.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/benchmark.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/check.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers2.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/compare.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/helpers.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/printer.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/run.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/uuid.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/hack/hack.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/ioutil.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/sectionwriter.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/atomic.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/semaphore.go create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/LICENSE create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/PATENTS create mode 100644 vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/context/context.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/canal.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/canal_test.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/config.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/dump.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/handler.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/master.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/rows.go create mode 100644 vendor/github.com/siddontang/go-mysql/canal/sync.go create mode 100644 vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go create mode 100644 vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go create mode 100644 vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go create mode 100644 vendor/github.com/siddontang/go-mysql/cmd/go-mysqldump/main.go create mode 100644 vendor/github.com/siddontang/go-mysql/docker/Makefile create mode 100644 vendor/github.com/siddontang/go-mysql/driver/dirver_test.go create mode 100644 vendor/github.com/siddontang/go-mysql/driver/driver.go create mode 100644 vendor/github.com/siddontang/go-mysql/dump/dump.go create mode 100644 vendor/github.com/siddontang/go-mysql/dump/dump_test.go create mode 100644 vendor/github.com/siddontang/go-mysql/dump/parser.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/const.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/doc.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/failover.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/failover_test.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/handler.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/mysql_gtid_handler.go create mode 100644 vendor/github.com/siddontang/go-mysql/failover/server.go create mode 100644 vendor/github.com/siddontang/go-mysql/glide.lock create mode 100644 vendor/github.com/siddontang/go-mysql/glide.yaml delete mode 100644 vendor/github.com/siddontang/go-mysql/replication/bitmap.go create mode 100644 vendor/github.com/siddontang/go-mysql/replication/json_binary.go create mode 100644 vendor/github.com/siddontang/go-mysql/schema/schema.go create mode 100644 vendor/github.com/siddontang/go-mysql/schema/schema_test.go create mode 100644 vendor/github.com/siddontang/go-mysql/server/auth.go create mode 100644 vendor/github.com/siddontang/go-mysql/server/command.go create mode 100644 vendor/github.com/siddontang/go-mysql/server/conn.go create mode 100644 vendor/github.com/siddontang/go-mysql/server/resp.go create mode 100644 vendor/github.com/siddontang/go-mysql/server/server_test.go create mode 100644 vendor/github.com/siddontang/go-mysql/server/stmt.go create mode 100644 vendor/github.com/siddontang/go-mysql/vitess_license diff --git a/vendor/github.com/siddontang/go-mysql/.gitignore b/vendor/github.com/siddontang/go-mysql/.gitignore new file mode 100644 index 0000000..68da35e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/.gitignore @@ -0,0 +1,3 @@ +var +bin +.idea \ No newline at end of file diff --git a/vendor/github.com/siddontang/go-mysql/.travis.yml b/vendor/github.com/siddontang/go-mysql/.travis.yml new file mode 100644 index 0000000..cc0db3c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/.travis.yml @@ -0,0 +1,32 @@ +language: go + +go: + - 1.6 + - 1.7 + +dist: trusty +sudo: required +addons: + apt: + packages: + - mysql-server-5.6 + - mysql-client-core-5.6 + - mysql-client-5.6 + +before_script: + # stop mysql and use row-based format binlog + - "sudo /etc/init.d/mysql stop || true" + - "echo '[mysqld]' | sudo tee /etc/mysql/conf.d/replication.cnf" + - "echo 'server-id=1' | sudo tee -a /etc/mysql/conf.d/replication.cnf" + - "echo 'log-bin=mysql' | sudo tee -a /etc/mysql/conf.d/replication.cnf" + - "echo 'binlog-format = row' | sudo tee -a /etc/mysql/conf.d/replication.cnf" + + # Start mysql (avoid errors to have logs) + - "sudo /etc/init.d/mysql start || true" + - "sudo tail -1000 /var/log/syslog" + + - mysql -e "CREATE DATABASE IF NOT EXISTS test;" -uroot + + +script: + - make test \ No newline at end of file diff --git a/vendor/github.com/siddontang/go-mysql/LICENSE b/vendor/github.com/siddontang/go-mysql/LICENSE new file mode 100644 index 0000000..80511a0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2014 siddontang + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/siddontang/go-mysql/Makefile b/vendor/github.com/siddontang/go-mysql/Makefile new file mode 100644 index 0000000..92744b1 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/Makefile @@ -0,0 +1,33 @@ +all: build + +build: + rm -rf vendor && ln -s _vendor/vendor vendor + go build -o bin/go-mysqlbinlog cmd/go-mysqlbinlog/main.go + go build -o bin/go-mysqldump cmd/go-mysqldump/main.go + go build -o bin/go-canal cmd/go-canal/main.go + go build -o bin/go-binlogparser cmd/go-binlogparser/main.go + rm -rf vendor + +test: + rm -rf vendor && ln -s _vendor/vendor vendor + go test --race -timeout 2m ./... + rm -rf vendor + +clean: + go clean -i ./... + @rm -rf ./bin + +update_vendor: + which glide >/dev/null || curl https://glide.sh/get | sh + which glide-vc || go get -v -u github.com/sgotti/glide-vc + rm -r vendor && mv _vendor/vendor vendor || true + rm -rf _vendor +ifdef PKG + glide get --strip-vendor --skip-test ${PKG} +else + glide update --strip-vendor --skip-test +endif + @echo "removing test files" + glide vc --only-code --no-tests + mkdir -p _vendor + mv vendor _vendor/vendor diff --git a/vendor/github.com/siddontang/go-mysql/README.md b/vendor/github.com/siddontang/go-mysql/README.md new file mode 100644 index 0000000..4ae6697 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/README.md @@ -0,0 +1,229 @@ +# go-mysql + +A pure go library to handle MySQL network protocol and replication. + +## Replication + +Replication package handles MySQL replication protocol like [python-mysql-replication](https://github.com/noplay/python-mysql-replication). + +You can use it as a MySQL slave to sync binlog from master then do something, like updating cache, etc... + +### Example + +```go +import ( + "github.com/siddontang/go-mysql/replication" + "os" +) +// Create a binlog syncer with a unique server id, the server id must be different from other MySQL's. +// flavor is mysql or mariadb +cfg := replication.BinlogSyncerConfig { + ServerID: 100, + Flavor: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Password: "", +} +syncer := replication.NewBinlogSyncer(&cfg) + +// Start sync with sepcified binlog file and position +streamer, _ := syncer.StartSync(mysql.Position{binlogFile, binlogPos}) + +// or you can start a gtid replication like +// streamer, _ := syncer.StartSyncGTID(gtidSet) +// the mysql GTID set likes this "de278ad0-2106-11e4-9f8e-6edd0ca20947:1-2" +// the mariadb GTID set likes this "0-1-100" + +for { + ev, _ := streamer.GetEvent(context.Background()) + // Dump event + ev.Dump(os.Stdout) +} + +// or we can use a timeout context +for { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + e, _ := s.GetEvent(ctx) + cancel() + + if err == context.DeadlineExceeded { + // meet timeout + continue + } + + ev.Dump(os.Stdout) +} +``` + +The output looks: + +``` +=== RotateEvent === +Date: 1970-01-01 08:00:00 +Log position: 0 +Event size: 43 +Position: 4 +Next log name: mysql.000002 + +=== FormatDescriptionEvent === +Date: 2014-12-18 16:36:09 +Log position: 120 +Event size: 116 +Version: 4 +Server version: 5.6.19-log +Create date: 2014-12-18 16:36:09 + +=== QueryEvent === +Date: 2014-12-18 16:38:24 +Log position: 259 +Event size: 139 +Salve proxy ID: 1 +Execution time: 0 +Error code: 0 +Schema: test +Query: DROP TABLE IF EXISTS `test_replication` /* generated by server */ +``` + +## Canal + +Canal is a package that can sync your MySQL into everywhere, like Redis, Elasticsearch. + +First, canal will dump your MySQL data then sync changed data using binlog incrementally. + +You must use ROW format for binlog, full binlog row image is preferred, because we may meet some errors when primary key changed in update for minimal or noblob row image. + +A simple example: + +``` +cfg := NewDefaultConfig() +cfg.Addr = "127.0.0.1:3306" +cfg.User = "root" +// We only care table canal_test in test db +cfg.Dump.TableDB = "test" +cfg.Dump.Tables = []string{"canal_test"} + +c, err := NewCanal(cfg) + +type myRowsEventHandler struct { +} + +func (h *myRowsEventHandler) Do(e *RowsEvent) error { + log.Infof("%s %v\n", e.Action, e.Rows) + return nil +} + +func (h *myRowsEventHandler) String() string { + return "myRowsEventHandler" +} + +// Register a handler to handle RowsEvent +c.RegRowsEventHandler(&MyRowsEventHandler{}) + +// Start canal +c.Start() +``` + +You can see [go-mysql-elasticsearch](https://github.com/siddontang/go-mysql-elasticsearch) for how to sync MySQL data into Elasticsearch. + +## Client + +Client package supports a simple MySQL connection driver which you can use it to communicate with MySQL server. + +### Example + +```go +import ( + "github.com/siddontang/go-mysql/client" +) + +// Connect MySQL at 127.0.0.1:3306, with user root, an empty passowrd and database test +conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test") + +conn.Ping() + +// Insert +r, _ := conn.Execute(`insert into table (id, name) values (1, "abc")`) + +// Get last insert id +println(r.InsertId) + +// Select +r, _ := conn.Execute(`select id, name from table where id = 1`) + +// Handle resultset +v, _ := r.GetInt(0, 0) +v, _ = r.GetIntByName(0, "id") +``` + +## Server + +Server package supplies a framework to implement a simple MySQL server which can handle the packets from the MySQL client. +You can use it to build your own MySQL proxy. + +### Example + +```go +import ( + "github.com/siddontang/go-mysql/server" + "net" +) + +l, _ := net.Listen("tcp", "127.0.0.1:4000") + +c, _ := l.Accept() + +// Create a connection with user root and an empty passowrd +// We only an empty handler to handle command too +conn, _ := server.NewConn(c, "root", "", server.EmptyHandler{}) + +for { + conn.HandleCommand() +} +``` + +Another shell + +``` +mysql -h127.0.0.1 -P4000 -uroot -p +//Becuase empty handler does nothing, so here the MySQL client can only connect the proxy server. :-) +``` + +## Failover + +Failover supports to promote a new master and let other slaves replicate from it automatically when the old master was down. + +Failover supports MySQL >= 5.6.9 with GTID mode, if you use lower version, e.g, MySQL 5.0 - 5.5, please use [MHA](http://code.google.com/p/mysql-master-ha/) or [orchestrator](https://github.com/outbrain/orchestrator). + +At the same time, Failover supports MariaDB >= 10.0.9 with GTID mode too. + +Why only GTID? Supporting failover with no GTID mode is very hard, because slave can not find the proper binlog filename and position with the new master. +Although there are many companies use MySQL 5.0 - 5.5, I think upgrade MySQL to 5.6 or higher is easy. + +## Driver + +Driver is the package that you can use go-mysql with go database/sql like other drivers. A simple example: + +``` +import ( + "database/sql" + + - "github.com/siddontang/go-mysql/driver" +) + +func main() { + // dsn format: "user:password@addr?dbname" + dsn := "root@127.0.0.1:3306?test" + db, _ := sql.Open(dsn) + db.Close() +} +``` + +We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) + +## Feedback + +go-mysql is still in development, your feedback is very welcome. + + +Gmail: siddontang@gmail.com diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/COPYING b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/COPYING new file mode 100644 index 0000000..5a8e332 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode.go new file mode 100644 index 0000000..6c7d398 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode.go @@ -0,0 +1,492 @@ +package toml + +import ( + "fmt" + "io" + "io/ioutil" + "math" + "reflect" + "strings" + "time" +) + +var e = fmt.Errorf + +// Unmarshaler is the interface implemented by objects that can unmarshal a +// TOML description of themselves. +type Unmarshaler interface { + UnmarshalTOML(interface{}) error +} + +// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`. +func Unmarshal(p []byte, v interface{}) error { + _, err := Decode(string(p), v) + return err +} + +// Primitive is a TOML value that hasn't been decoded into a Go value. +// When using the various `Decode*` functions, the type `Primitive` may +// be given to any value, and its decoding will be delayed. +// +// A `Primitive` value can be decoded using the `PrimitiveDecode` function. +// +// The underlying representation of a `Primitive` value is subject to change. +// Do not rely on it. +// +// N.B. Primitive values are still parsed, so using them will only avoid +// the overhead of reflection. They can be useful when you don't know the +// exact type of TOML data until run time. +type Primitive struct { + undecoded interface{} + context Key +} + +// DEPRECATED! +// +// Use MetaData.PrimitiveDecode instead. +func PrimitiveDecode(primValue Primitive, v interface{}) error { + md := MetaData{decoded: make(map[string]bool)} + return md.unify(primValue.undecoded, rvalue(v)) +} + +// PrimitiveDecode is just like the other `Decode*` functions, except it +// decodes a TOML value that has already been parsed. Valid primitive values +// can *only* be obtained from values filled by the decoder functions, +// including this method. (i.e., `v` may contain more `Primitive` +// values.) +// +// Meta data for primitive values is included in the meta data returned by +// the `Decode*` functions with one exception: keys returned by the Undecoded +// method will only reflect keys that were decoded. Namely, any keys hidden +// behind a Primitive will be considered undecoded. Executing this method will +// update the undecoded keys in the meta data. (See the example.) +func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error { + md.context = primValue.context + defer func() { md.context = nil }() + return md.unify(primValue.undecoded, rvalue(v)) +} + +// Decode will decode the contents of `data` in TOML format into a pointer +// `v`. +// +// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be +// used interchangeably.) +// +// TOML arrays of tables correspond to either a slice of structs or a slice +// of maps. +// +// TOML datetimes correspond to Go `time.Time` values. +// +// All other TOML types (float, string, int, bool and array) correspond +// to the obvious Go types. +// +// An exception to the above rules is if a type implements the +// encoding.TextUnmarshaler interface. In this case, any primitive TOML value +// (floats, strings, integers, booleans and datetimes) will be converted to +// a byte string and given to the value's UnmarshalText method. See the +// Unmarshaler example for a demonstration with time duration strings. +// +// Key mapping +// +// TOML keys can map to either keys in a Go map or field names in a Go +// struct. The special `toml` struct tag may be used to map TOML keys to +// struct fields that don't match the key name exactly. (See the example.) +// A case insensitive match to struct names will be tried if an exact match +// can't be found. +// +// The mapping between TOML values and Go values is loose. That is, there +// may exist TOML values that cannot be placed into your representation, and +// there may be parts of your representation that do not correspond to +// TOML values. This loose mapping can be made stricter by using the IsDefined +// and/or Undecoded methods on the MetaData returned. +// +// This decoder will not handle cyclic types. If a cyclic type is passed, +// `Decode` will not terminate. +func Decode(data string, v interface{}) (MetaData, error) { + p, err := parse(data) + if err != nil { + return MetaData{}, err + } + md := MetaData{ + p.mapping, p.types, p.ordered, + make(map[string]bool, len(p.ordered)), nil, + } + return md, md.unify(p.mapping, rvalue(v)) +} + +// DecodeFile is just like Decode, except it will automatically read the +// contents of the file at `fpath` and decode it for you. +func DecodeFile(fpath string, v interface{}) (MetaData, error) { + bs, err := ioutil.ReadFile(fpath) + if err != nil { + return MetaData{}, err + } + return Decode(string(bs), v) +} + +// DecodeReader is just like Decode, except it will consume all bytes +// from the reader and decode it for you. +func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { + bs, err := ioutil.ReadAll(r) + if err != nil { + return MetaData{}, err + } + return Decode(string(bs), v) +} + +// unify performs a sort of type unification based on the structure of `rv`, +// which is the client representation. +// +// Any type mismatch produces an error. Finding a type that we don't know +// how to handle produces an unsupported type error. +func (md *MetaData) unify(data interface{}, rv reflect.Value) error { + + // Special case. Look for a `Primitive` value. + if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() { + // Save the undecoded data and the key context into the primitive + // value. + context := make(Key, len(md.context)) + copy(context, md.context) + rv.Set(reflect.ValueOf(Primitive{ + undecoded: data, + context: context, + })) + return nil + } + + // Special case. Unmarshaler Interface support. + if rv.CanAddr() { + if v, ok := rv.Addr().Interface().(Unmarshaler); ok { + return v.UnmarshalTOML(data) + } + } + + // Special case. Handle time.Time values specifically. + // TODO: Remove this code when we decide to drop support for Go 1.1. + // This isn't necessary in Go 1.2 because time.Time satisfies the encoding + // interfaces. + if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) { + return md.unifyDatetime(data, rv) + } + + // Special case. Look for a value satisfying the TextUnmarshaler interface. + if v, ok := rv.Interface().(TextUnmarshaler); ok { + return md.unifyText(data, v) + } + // BUG(burntsushi) + // The behavior here is incorrect whenever a Go type satisfies the + // encoding.TextUnmarshaler interface but also corresponds to a TOML + // hash or array. In particular, the unmarshaler should only be applied + // to primitive TOML values. But at this point, it will be applied to + // all kinds of values and produce an incorrect error whenever those values + // are hashes or arrays (including arrays of tables). + + k := rv.Kind() + + // laziness + if k >= reflect.Int && k <= reflect.Uint64 { + return md.unifyInt(data, rv) + } + switch k { + case reflect.Ptr: + elem := reflect.New(rv.Type().Elem()) + err := md.unify(data, reflect.Indirect(elem)) + if err != nil { + return err + } + rv.Set(elem) + return nil + case reflect.Struct: + return md.unifyStruct(data, rv) + case reflect.Map: + return md.unifyMap(data, rv) + case reflect.Array: + return md.unifyArray(data, rv) + case reflect.Slice: + return md.unifySlice(data, rv) + case reflect.String: + return md.unifyString(data, rv) + case reflect.Bool: + return md.unifyBool(data, rv) + case reflect.Interface: + // we only support empty interfaces. + if rv.NumMethod() > 0 { + return e("Unsupported type '%s'.", rv.Kind()) + } + return md.unifyAnything(data, rv) + case reflect.Float32: + fallthrough + case reflect.Float64: + return md.unifyFloat64(data, rv) + } + return e("Unsupported type '%s'.", rv.Kind()) +} + +func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { + tmap, ok := mapping.(map[string]interface{}) + if !ok { + return mismatch(rv, "map", mapping) + } + + for key, datum := range tmap { + var f *field + fields := cachedTypeFields(rv.Type()) + for i := range fields { + ff := &fields[i] + if ff.name == key { + f = ff + break + } + if f == nil && strings.EqualFold(ff.name, key) { + f = ff + } + } + if f != nil { + subv := rv + for _, i := range f.index { + subv = indirect(subv.Field(i)) + } + if isUnifiable(subv) { + md.decoded[md.context.add(key).String()] = true + md.context = append(md.context, key) + if err := md.unify(datum, subv); err != nil { + return e("Type mismatch for '%s.%s': %s", + rv.Type().String(), f.name, err) + } + md.context = md.context[0 : len(md.context)-1] + } else if f.name != "" { + // Bad user! No soup for you! + return e("Field '%s.%s' is unexported, and therefore cannot "+ + "be loaded with reflection.", rv.Type().String(), f.name) + } + } + } + return nil +} + +func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { + tmap, ok := mapping.(map[string]interface{}) + if !ok { + return badtype("map", mapping) + } + if rv.IsNil() { + rv.Set(reflect.MakeMap(rv.Type())) + } + for k, v := range tmap { + md.decoded[md.context.add(k).String()] = true + md.context = append(md.context, k) + + rvkey := indirect(reflect.New(rv.Type().Key())) + rvval := reflect.Indirect(reflect.New(rv.Type().Elem())) + if err := md.unify(v, rvval); err != nil { + return err + } + md.context = md.context[0 : len(md.context)-1] + + rvkey.SetString(k) + rv.SetMapIndex(rvkey, rvval) + } + return nil +} + +func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { + datav := reflect.ValueOf(data) + if datav.Kind() != reflect.Slice { + return badtype("slice", data) + } + sliceLen := datav.Len() + if sliceLen != rv.Len() { + return e("expected array length %d; got TOML array of length %d", + rv.Len(), sliceLen) + } + return md.unifySliceArray(datav, rv) +} + +func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error { + datav := reflect.ValueOf(data) + if datav.Kind() != reflect.Slice { + return badtype("slice", data) + } + sliceLen := datav.Len() + if rv.IsNil() { + rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen)) + } + return md.unifySliceArray(datav, rv) +} + +func (md *MetaData) unifySliceArray(data, rv reflect.Value) error { + sliceLen := data.Len() + for i := 0; i < sliceLen; i++ { + v := data.Index(i).Interface() + sliceval := indirect(rv.Index(i)) + if err := md.unify(v, sliceval); err != nil { + return err + } + } + return nil +} + +func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error { + if _, ok := data.(time.Time); ok { + rv.Set(reflect.ValueOf(data)) + return nil + } + return badtype("time.Time", data) +} + +func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error { + if s, ok := data.(string); ok { + rv.SetString(s) + return nil + } + return badtype("string", data) +} + +func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error { + if num, ok := data.(float64); ok { + switch rv.Kind() { + case reflect.Float32: + fallthrough + case reflect.Float64: + rv.SetFloat(num) + default: + panic("bug") + } + return nil + } + return badtype("float", data) +} + +func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error { + if num, ok := data.(int64); ok { + if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 { + switch rv.Kind() { + case reflect.Int, reflect.Int64: + // No bounds checking necessary. + case reflect.Int8: + if num < math.MinInt8 || num > math.MaxInt8 { + return e("Value '%d' is out of range for int8.", num) + } + case reflect.Int16: + if num < math.MinInt16 || num > math.MaxInt16 { + return e("Value '%d' is out of range for int16.", num) + } + case reflect.Int32: + if num < math.MinInt32 || num > math.MaxInt32 { + return e("Value '%d' is out of range for int32.", num) + } + } + rv.SetInt(num) + } else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 { + unum := uint64(num) + switch rv.Kind() { + case reflect.Uint, reflect.Uint64: + // No bounds checking necessary. + case reflect.Uint8: + if num < 0 || unum > math.MaxUint8 { + return e("Value '%d' is out of range for uint8.", num) + } + case reflect.Uint16: + if num < 0 || unum > math.MaxUint16 { + return e("Value '%d' is out of range for uint16.", num) + } + case reflect.Uint32: + if num < 0 || unum > math.MaxUint32 { + return e("Value '%d' is out of range for uint32.", num) + } + } + rv.SetUint(unum) + } else { + panic("unreachable") + } + return nil + } + return badtype("integer", data) +} + +func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error { + if b, ok := data.(bool); ok { + rv.SetBool(b) + return nil + } + return badtype("boolean", data) +} + +func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error { + rv.Set(reflect.ValueOf(data)) + return nil +} + +func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error { + var s string + switch sdata := data.(type) { + case TextMarshaler: + text, err := sdata.MarshalText() + if err != nil { + return err + } + s = string(text) + case fmt.Stringer: + s = sdata.String() + case string: + s = sdata + case bool: + s = fmt.Sprintf("%v", sdata) + case int64: + s = fmt.Sprintf("%d", sdata) + case float64: + s = fmt.Sprintf("%f", sdata) + default: + return badtype("primitive (string-like)", data) + } + if err := v.UnmarshalText([]byte(s)); err != nil { + return err + } + return nil +} + +// rvalue returns a reflect.Value of `v`. All pointers are resolved. +func rvalue(v interface{}) reflect.Value { + return indirect(reflect.ValueOf(v)) +} + +// indirect returns the value pointed to by a pointer. +// Pointers are followed until the value is not a pointer. +// New values are allocated for each nil pointer. +// +// An exception to this rule is if the value satisfies an interface of +// interest to us (like encoding.TextUnmarshaler). +func indirect(v reflect.Value) reflect.Value { + if v.Kind() != reflect.Ptr { + if v.CanAddr() { + pv := v.Addr() + if _, ok := pv.Interface().(TextUnmarshaler); ok { + return pv + } + } + return v + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + return indirect(reflect.Indirect(v)) +} + +func isUnifiable(rv reflect.Value) bool { + if rv.CanSet() { + return true + } + if _, ok := rv.Interface().(TextUnmarshaler); ok { + return true + } + return false +} + +func badtype(expected string, data interface{}) error { + return e("Expected %s but found '%T'.", expected, data) +} + +func mismatch(user reflect.Value, expected string, data interface{}) error { + return e("Type mismatch for %s. Expected %s but found '%T'.", + user.Type().String(), expected, data) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode_meta.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode_meta.go new file mode 100644 index 0000000..ef6f545 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/decode_meta.go @@ -0,0 +1,122 @@ +package toml + +import "strings" + +// MetaData allows access to meta information about TOML data that may not +// be inferrable via reflection. In particular, whether a key has been defined +// and the TOML type of a key. +type MetaData struct { + mapping map[string]interface{} + types map[string]tomlType + keys []Key + decoded map[string]bool + context Key // Used only during decoding. +} + +// IsDefined returns true if the key given exists in the TOML data. The key +// should be specified hierarchially. e.g., +// +// // access the TOML key 'a.b.c' +// IsDefined("a", "b", "c") +// +// IsDefined will return false if an empty key given. Keys are case sensitive. +func (md *MetaData) IsDefined(key ...string) bool { + if len(key) == 0 { + return false + } + + var hash map[string]interface{} + var ok bool + var hashOrVal interface{} = md.mapping + for _, k := range key { + if hash, ok = hashOrVal.(map[string]interface{}); !ok { + return false + } + if hashOrVal, ok = hash[k]; !ok { + return false + } + } + return true +} + +// Type returns a string representation of the type of the key specified. +// +// Type will return the empty string if given an empty key or a key that +// does not exist. Keys are case sensitive. +func (md *MetaData) Type(key ...string) string { + fullkey := strings.Join(key, ".") + if typ, ok := md.types[fullkey]; ok { + return typ.typeString() + } + return "" +} + +// Key is the type of any TOML key, including key groups. Use (MetaData).Keys +// to get values of this type. +type Key []string + +func (k Key) String() string { + return strings.Join(k, ".") +} + +func (k Key) maybeQuotedAll() string { + var ss []string + for i := range k { + ss = append(ss, k.maybeQuoted(i)) + } + return strings.Join(ss, ".") +} + +func (k Key) maybeQuoted(i int) string { + quote := false + for _, c := range k[i] { + if !isBareKeyChar(c) { + quote = true + break + } + } + if quote { + return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\"" + } else { + return k[i] + } +} + +func (k Key) add(piece string) Key { + newKey := make(Key, len(k)+1) + copy(newKey, k) + newKey[len(k)] = piece + return newKey +} + +// Keys returns a slice of every key in the TOML data, including key groups. +// Each key is itself a slice, where the first element is the top of the +// hierarchy and the last is the most specific. +// +// The list will have the same order as the keys appeared in the TOML data. +// +// All keys returned are non-empty. +func (md *MetaData) Keys() []Key { + return md.keys +} + +// Undecoded returns all keys that have not been decoded in the order in which +// they appear in the original TOML document. +// +// This includes keys that haven't been decoded because of a Primitive value. +// Once the Primitive value is decoded, the keys will be considered decoded. +// +// Also note that decoding into an empty interface will result in no decoding, +// and so no keys will be considered decoded. +// +// In this sense, the Undecoded keys correspond to keys in the TOML document +// that do not have a concrete type in your representation. +func (md *MetaData) Undecoded() []Key { + undecoded := make([]Key, 0, len(md.keys)) + for _, key := range md.keys { + if !md.decoded[key.String()] { + undecoded = append(undecoded, key) + } + } + return undecoded +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/doc.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/doc.go new file mode 100644 index 0000000..fe26800 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/doc.go @@ -0,0 +1,27 @@ +/* +Package toml provides facilities for decoding and encoding TOML configuration +files via reflection. There is also support for delaying decoding with +the Primitive type, and querying the set of keys in a TOML document with the +MetaData type. + +The specification implemented: https://github.com/mojombo/toml + +The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify +whether a file is a valid TOML document. It can also be used to print the +type of each key in a TOML document. + +Testing + +There are two important types of tests used for this package. The first is +contained inside '*_test.go' files and uses the standard Go unit testing +framework. These tests are primarily devoted to holistically testing the +decoder and encoder. + +The second type of testing is used to verify the implementation's adherence +to the TOML specification. These tests have been factored into their own +project: https://github.com/BurntSushi/toml-test + +The reason the tests are in a separate project is so that they can be used by +any implementation of TOML. Namely, it is language agnostic. +*/ +package toml diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encode.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encode.go new file mode 100644 index 0000000..c7e227c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encode.go @@ -0,0 +1,551 @@ +package toml + +import ( + "bufio" + "errors" + "fmt" + "io" + "reflect" + "sort" + "strconv" + "strings" + "time" +) + +type tomlEncodeError struct{ error } + +var ( + errArrayMixedElementTypes = errors.New( + "can't encode array with mixed element types") + errArrayNilElement = errors.New( + "can't encode array with nil element") + errNonString = errors.New( + "can't encode a map with non-string key type") + errAnonNonStruct = errors.New( + "can't encode an anonymous field that is not a struct") + errArrayNoTable = errors.New( + "TOML array element can't contain a table") + errNoKey = errors.New( + "top-level values must be a Go map or struct") + errAnything = errors.New("") // used in testing +) + +var quotedReplacer = strings.NewReplacer( + "\t", "\\t", + "\n", "\\n", + "\r", "\\r", + "\"", "\\\"", + "\\", "\\\\", +) + +// Encoder controls the encoding of Go values to a TOML document to some +// io.Writer. +// +// The indentation level can be controlled with the Indent field. +type Encoder struct { + // A single indentation level. By default it is two spaces. + Indent string + + // hasWritten is whether we have written any output to w yet. + hasWritten bool + w *bufio.Writer +} + +// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer +// given. By default, a single indentation level is 2 spaces. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + w: bufio.NewWriter(w), + Indent: " ", + } +} + +// Encode writes a TOML representation of the Go value to the underlying +// io.Writer. If the value given cannot be encoded to a valid TOML document, +// then an error is returned. +// +// The mapping between Go values and TOML values should be precisely the same +// as for the Decode* functions. Similarly, the TextMarshaler interface is +// supported by encoding the resulting bytes as strings. (If you want to write +// arbitrary binary data then you will need to use something like base64 since +// TOML does not have any binary types.) +// +// When encoding TOML hashes (i.e., Go maps or structs), keys without any +// sub-hashes are encoded first. +// +// If a Go map is encoded, then its keys are sorted alphabetically for +// deterministic output. More control over this behavior may be provided if +// there is demand for it. +// +// Encoding Go values without a corresponding TOML representation---like map +// types with non-string keys---will cause an error to be returned. Similarly +// for mixed arrays/slices, arrays/slices with nil elements, embedded +// non-struct types and nested slices containing maps or structs. +// (e.g., [][]map[string]string is not allowed but []map[string]string is OK +// and so is []map[string][]string.) +func (enc *Encoder) Encode(v interface{}) error { + rv := eindirect(reflect.ValueOf(v)) + if err := enc.safeEncode(Key([]string{}), rv); err != nil { + return err + } + return enc.w.Flush() +} + +func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) { + defer func() { + if r := recover(); r != nil { + if terr, ok := r.(tomlEncodeError); ok { + err = terr.error + return + } + panic(r) + } + }() + enc.encode(key, rv) + return nil +} + +func (enc *Encoder) encode(key Key, rv reflect.Value) { + // Special case. Time needs to be in ISO8601 format. + // Special case. If we can marshal the type to text, then we used that. + // Basically, this prevents the encoder for handling these types as + // generic structs (or whatever the underlying type of a TextMarshaler is). + switch rv.Interface().(type) { + case time.Time, TextMarshaler: + enc.keyEqElement(key, rv) + return + } + + k := rv.Kind() + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, + reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: + enc.keyEqElement(key, rv) + case reflect.Array, reflect.Slice: + if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) { + enc.eArrayOfTables(key, rv) + } else { + enc.keyEqElement(key, rv) + } + case reflect.Interface: + if rv.IsNil() { + return + } + enc.encode(key, rv.Elem()) + case reflect.Map: + if rv.IsNil() { + return + } + enc.eTable(key, rv) + case reflect.Ptr: + if rv.IsNil() { + return + } + enc.encode(key, rv.Elem()) + case reflect.Struct: + enc.eTable(key, rv) + default: + panic(e("Unsupported type for key '%s': %s", key, k)) + } +} + +// eElement encodes any value that can be an array element (primitives and +// arrays). +func (enc *Encoder) eElement(rv reflect.Value) { + switch v := rv.Interface().(type) { + case time.Time: + // Special case time.Time as a primitive. Has to come before + // TextMarshaler below because time.Time implements + // encoding.TextMarshaler, but we need to always use UTC. + enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z")) + return + case TextMarshaler: + // Special case. Use text marshaler if it's available for this value. + if s, err := v.MarshalText(); err != nil { + encPanic(err) + } else { + enc.writeQuoted(string(s)) + } + return + } + switch rv.Kind() { + case reflect.Bool: + enc.wf(strconv.FormatBool(rv.Bool())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64: + enc.wf(strconv.FormatInt(rv.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, + reflect.Uint32, reflect.Uint64: + enc.wf(strconv.FormatUint(rv.Uint(), 10)) + case reflect.Float32: + enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32))) + case reflect.Float64: + enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64))) + case reflect.Array, reflect.Slice: + enc.eArrayOrSliceElement(rv) + case reflect.Interface: + enc.eElement(rv.Elem()) + case reflect.String: + enc.writeQuoted(rv.String()) + default: + panic(e("Unexpected primitive type: %s", rv.Kind())) + } +} + +// By the TOML spec, all floats must have a decimal with at least one +// number on either side. +func floatAddDecimal(fstr string) string { + if !strings.Contains(fstr, ".") { + return fstr + ".0" + } + return fstr +} + +func (enc *Encoder) writeQuoted(s string) { + enc.wf("\"%s\"", quotedReplacer.Replace(s)) +} + +func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { + length := rv.Len() + enc.wf("[") + for i := 0; i < length; i++ { + elem := rv.Index(i) + enc.eElement(elem) + if i != length-1 { + enc.wf(", ") + } + } + enc.wf("]") +} + +func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { + if len(key) == 0 { + encPanic(errNoKey) + } + for i := 0; i < rv.Len(); i++ { + trv := rv.Index(i) + if isNil(trv) { + continue + } + panicIfInvalidKey(key) + enc.newline() + enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll()) + enc.newline() + enc.eMapOrStruct(key, trv) + } +} + +func (enc *Encoder) eTable(key Key, rv reflect.Value) { + panicIfInvalidKey(key) + if len(key) == 1 { + // Output an extra new line between top-level tables. + // (The newline isn't written if nothing else has been written though.) + enc.newline() + } + if len(key) > 0 { + enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll()) + enc.newline() + } + enc.eMapOrStruct(key, rv) +} + +func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) { + switch rv := eindirect(rv); rv.Kind() { + case reflect.Map: + enc.eMap(key, rv) + case reflect.Struct: + enc.eStruct(key, rv) + default: + panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String()) + } +} + +func (enc *Encoder) eMap(key Key, rv reflect.Value) { + rt := rv.Type() + if rt.Key().Kind() != reflect.String { + encPanic(errNonString) + } + + // Sort keys so that we have deterministic output. And write keys directly + // underneath this key first, before writing sub-structs or sub-maps. + var mapKeysDirect, mapKeysSub []string + for _, mapKey := range rv.MapKeys() { + k := mapKey.String() + if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) { + mapKeysSub = append(mapKeysSub, k) + } else { + mapKeysDirect = append(mapKeysDirect, k) + } + } + + var writeMapKeys = func(mapKeys []string) { + sort.Strings(mapKeys) + for _, mapKey := range mapKeys { + mrv := rv.MapIndex(reflect.ValueOf(mapKey)) + if isNil(mrv) { + // Don't write anything for nil fields. + continue + } + enc.encode(key.add(mapKey), mrv) + } + } + writeMapKeys(mapKeysDirect) + writeMapKeys(mapKeysSub) +} + +func (enc *Encoder) eStruct(key Key, rv reflect.Value) { + // Write keys for fields directly under this key first, because if we write + // a field that creates a new table, then all keys under it will be in that + // table (not the one we're writing here). + rt := rv.Type() + var fieldsDirect, fieldsSub [][]int + var addFields func(rt reflect.Type, rv reflect.Value, start []int) + addFields = func(rt reflect.Type, rv reflect.Value, start []int) { + for i := 0; i < rt.NumField(); i++ { + f := rt.Field(i) + // skip unexporded fields + if f.PkgPath != "" { + continue + } + frv := rv.Field(i) + if f.Anonymous { + frv := eindirect(frv) + t := frv.Type() + if t.Kind() != reflect.Struct { + encPanic(errAnonNonStruct) + } + addFields(t, frv, f.Index) + } else if typeIsHash(tomlTypeOfGo(frv)) { + fieldsSub = append(fieldsSub, append(start, f.Index...)) + } else { + fieldsDirect = append(fieldsDirect, append(start, f.Index...)) + } + } + } + addFields(rt, rv, nil) + + var writeFields = func(fields [][]int) { + for _, fieldIndex := range fields { + sft := rt.FieldByIndex(fieldIndex) + sf := rv.FieldByIndex(fieldIndex) + if isNil(sf) { + // Don't write anything for nil fields. + continue + } + + keyName := sft.Tag.Get("toml") + if keyName == "-" { + continue + } + if keyName == "" { + keyName = sft.Name + } + + keyName, opts := getOptions(keyName) + if _, ok := opts["omitempty"]; ok && isEmpty(sf) { + continue + } else if _, ok := opts["omitzero"]; ok && isZero(sf) { + continue + } + + enc.encode(key.add(keyName), sf) + } + } + writeFields(fieldsDirect) + writeFields(fieldsSub) +} + +// tomlTypeName returns the TOML type name of the Go value's type. It is +// used to determine whether the types of array elements are mixed (which is +// forbidden). If the Go value is nil, then it is illegal for it to be an array +// element, and valueIsNil is returned as true. + +// Returns the TOML type of a Go value. The type may be `nil`, which means +// no concrete TOML type could be found. +func tomlTypeOfGo(rv reflect.Value) tomlType { + if isNil(rv) || !rv.IsValid() { + return nil + } + switch rv.Kind() { + case reflect.Bool: + return tomlBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, + reflect.Uint64: + return tomlInteger + case reflect.Float32, reflect.Float64: + return tomlFloat + case reflect.Array, reflect.Slice: + if typeEqual(tomlHash, tomlArrayType(rv)) { + return tomlArrayHash + } else { + return tomlArray + } + case reflect.Ptr, reflect.Interface: + return tomlTypeOfGo(rv.Elem()) + case reflect.String: + return tomlString + case reflect.Map: + return tomlHash + case reflect.Struct: + switch rv.Interface().(type) { + case time.Time: + return tomlDatetime + case TextMarshaler: + return tomlString + default: + return tomlHash + } + default: + panic("unexpected reflect.Kind: " + rv.Kind().String()) + } +} + +// tomlArrayType returns the element type of a TOML array. The type returned +// may be nil if it cannot be determined (e.g., a nil slice or a zero length +// slize). This function may also panic if it finds a type that cannot be +// expressed in TOML (such as nil elements, heterogeneous arrays or directly +// nested arrays of tables). +func tomlArrayType(rv reflect.Value) tomlType { + if isNil(rv) || !rv.IsValid() || rv.Len() == 0 { + return nil + } + firstType := tomlTypeOfGo(rv.Index(0)) + if firstType == nil { + encPanic(errArrayNilElement) + } + + rvlen := rv.Len() + for i := 1; i < rvlen; i++ { + elem := rv.Index(i) + switch elemType := tomlTypeOfGo(elem); { + case elemType == nil: + encPanic(errArrayNilElement) + case !typeEqual(firstType, elemType): + encPanic(errArrayMixedElementTypes) + } + } + // If we have a nested array, then we must make sure that the nested + // array contains ONLY primitives. + // This checks arbitrarily nested arrays. + if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) { + nest := tomlArrayType(eindirect(rv.Index(0))) + if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) { + encPanic(errArrayNoTable) + } + } + return firstType +} + +func getOptions(keyName string) (string, map[string]struct{}) { + opts := make(map[string]struct{}) + ss := strings.Split(keyName, ",") + name := ss[0] + if len(ss) > 1 { + for _, opt := range ss { + opts[opt] = struct{}{} + } + } + + return name, opts +} + +func isZero(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if rv.Int() == 0 { + return true + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if rv.Uint() == 0 { + return true + } + case reflect.Float32, reflect.Float64: + if rv.Float() == 0.0 { + return true + } + } + + return false +} + +func isEmpty(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.String: + if len(strings.TrimSpace(rv.String())) == 0 { + return true + } + case reflect.Array, reflect.Slice, reflect.Map: + if rv.Len() == 0 { + return true + } + } + + return false +} + +func (enc *Encoder) newline() { + if enc.hasWritten { + enc.wf("\n") + } +} + +func (enc *Encoder) keyEqElement(key Key, val reflect.Value) { + if len(key) == 0 { + encPanic(errNoKey) + } + panicIfInvalidKey(key) + enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1)) + enc.eElement(val) + enc.newline() +} + +func (enc *Encoder) wf(format string, v ...interface{}) { + if _, err := fmt.Fprintf(enc.w, format, v...); err != nil { + encPanic(err) + } + enc.hasWritten = true +} + +func (enc *Encoder) indentStr(key Key) string { + return strings.Repeat(enc.Indent, len(key)-1) +} + +func encPanic(err error) { + panic(tomlEncodeError{err}) +} + +func eindirect(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + return eindirect(v.Elem()) + default: + return v + } +} + +func isNil(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return rv.IsNil() + default: + return false + } +} + +func panicIfInvalidKey(key Key) { + for _, k := range key { + if len(k) == 0 { + encPanic(e("Key '%s' is not a valid table name. Key names "+ + "cannot be empty.", key.maybeQuotedAll())) + } + } +} + +func isValidKeyName(s string) bool { + return len(s) != 0 +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types.go new file mode 100644 index 0000000..d36e1dd --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types.go @@ -0,0 +1,19 @@ +// +build go1.2 + +package toml + +// In order to support Go 1.1, we define our own TextMarshaler and +// TextUnmarshaler types. For Go 1.2+, we just alias them with the +// standard library interfaces. + +import ( + "encoding" +) + +// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here +// so that Go 1.1 can be supported. +type TextMarshaler encoding.TextMarshaler + +// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined +// here so that Go 1.1 can be supported. +type TextUnmarshaler encoding.TextUnmarshaler diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go new file mode 100644 index 0000000..e8d503d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go @@ -0,0 +1,18 @@ +// +build !go1.2 + +package toml + +// These interfaces were introduced in Go 1.2, so we add them manually when +// compiling for Go 1.1. + +// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here +// so that Go 1.1 can be supported. +type TextMarshaler interface { + MarshalText() (text []byte, err error) +} + +// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined +// here so that Go 1.1 can be supported. +type TextUnmarshaler interface { + UnmarshalText(text []byte) error +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/lex.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/lex.go new file mode 100644 index 0000000..2191228 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/lex.go @@ -0,0 +1,874 @@ +package toml + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +type itemType int + +const ( + itemError itemType = iota + itemNIL // used in the parser to indicate no type + itemEOF + itemText + itemString + itemRawString + itemMultilineString + itemRawMultilineString + itemBool + itemInteger + itemFloat + itemDatetime + itemArray // the start of an array + itemArrayEnd + itemTableStart + itemTableEnd + itemArrayTableStart + itemArrayTableEnd + itemKeyStart + itemCommentStart +) + +const ( + eof = 0 + tableStart = '[' + tableEnd = ']' + arrayTableStart = '[' + arrayTableEnd = ']' + tableSep = '.' + keySep = '=' + arrayStart = '[' + arrayEnd = ']' + arrayValTerm = ',' + commentStart = '#' + stringStart = '"' + stringEnd = '"' + rawStringStart = '\'' + rawStringEnd = '\'' +) + +type stateFn func(lx *lexer) stateFn + +type lexer struct { + input string + start int + pos int + width int + line int + state stateFn + items chan item + + // A stack of state functions used to maintain context. + // The idea is to reuse parts of the state machine in various places. + // For example, values can appear at the top level or within arbitrarily + // nested arrays. The last state on the stack is used after a value has + // been lexed. Similarly for comments. + stack []stateFn +} + +type item struct { + typ itemType + val string + line int +} + +func (lx *lexer) nextItem() item { + for { + select { + case item := <-lx.items: + return item + default: + lx.state = lx.state(lx) + } + } +} + +func lex(input string) *lexer { + lx := &lexer{ + input: input + "\n", + state: lexTop, + line: 1, + items: make(chan item, 10), + stack: make([]stateFn, 0, 10), + } + return lx +} + +func (lx *lexer) push(state stateFn) { + lx.stack = append(lx.stack, state) +} + +func (lx *lexer) pop() stateFn { + if len(lx.stack) == 0 { + return lx.errorf("BUG in lexer: no states to pop.") + } + last := lx.stack[len(lx.stack)-1] + lx.stack = lx.stack[0 : len(lx.stack)-1] + return last +} + +func (lx *lexer) current() string { + return lx.input[lx.start:lx.pos] +} + +func (lx *lexer) emit(typ itemType) { + lx.items <- item{typ, lx.current(), lx.line} + lx.start = lx.pos +} + +func (lx *lexer) emitTrim(typ itemType) { + lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line} + lx.start = lx.pos +} + +func (lx *lexer) next() (r rune) { + if lx.pos >= len(lx.input) { + lx.width = 0 + return eof + } + + if lx.input[lx.pos] == '\n' { + lx.line++ + } + r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:]) + lx.pos += lx.width + return r +} + +// ignore skips over the pending input before this point. +func (lx *lexer) ignore() { + lx.start = lx.pos +} + +// backup steps back one rune. Can be called only once per call of next. +func (lx *lexer) backup() { + lx.pos -= lx.width + if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' { + lx.line-- + } +} + +// accept consumes the next rune if it's equal to `valid`. +func (lx *lexer) accept(valid rune) bool { + if lx.next() == valid { + return true + } + lx.backup() + return false +} + +// peek returns but does not consume the next rune in the input. +func (lx *lexer) peek() rune { + r := lx.next() + lx.backup() + return r +} + +// errorf stops all lexing by emitting an error and returning `nil`. +// Note that any value that is a character is escaped if it's a special +// character (new lines, tabs, etc.). +func (lx *lexer) errorf(format string, values ...interface{}) stateFn { + lx.items <- item{ + itemError, + fmt.Sprintf(format, values...), + lx.line, + } + return nil +} + +// lexTop consumes elements at the top level of TOML data. +func lexTop(lx *lexer) stateFn { + r := lx.next() + if isWhitespace(r) || isNL(r) { + return lexSkip(lx, lexTop) + } + + switch r { + case commentStart: + lx.push(lexTop) + return lexCommentStart + case tableStart: + return lexTableStart + case eof: + if lx.pos > lx.start { + return lx.errorf("Unexpected EOF.") + } + lx.emit(itemEOF) + return nil + } + + // At this point, the only valid item can be a key, so we back up + // and let the key lexer do the rest. + lx.backup() + lx.push(lexTopEnd) + return lexKeyStart +} + +// lexTopEnd is entered whenever a top-level item has been consumed. (A value +// or a table.) It must see only whitespace, and will turn back to lexTop +// upon a new line. If it sees EOF, it will quit the lexer successfully. +func lexTopEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case r == commentStart: + // a comment will read to a new line for us. + lx.push(lexTop) + return lexCommentStart + case isWhitespace(r): + return lexTopEnd + case isNL(r): + lx.ignore() + return lexTop + case r == eof: + lx.ignore() + return lexTop + } + return lx.errorf("Expected a top-level item to end with a new line, "+ + "comment or EOF, but got %q instead.", r) +} + +// lexTable lexes the beginning of a table. Namely, it makes sure that +// it starts with a character other than '.' and ']'. +// It assumes that '[' has already been consumed. +// It also handles the case that this is an item in an array of tables. +// e.g., '[[name]]'. +func lexTableStart(lx *lexer) stateFn { + if lx.peek() == arrayTableStart { + lx.next() + lx.emit(itemArrayTableStart) + lx.push(lexArrayTableEnd) + } else { + lx.emit(itemTableStart) + lx.push(lexTableEnd) + } + return lexTableNameStart +} + +func lexTableEnd(lx *lexer) stateFn { + lx.emit(itemTableEnd) + return lexTopEnd +} + +func lexArrayTableEnd(lx *lexer) stateFn { + if r := lx.next(); r != arrayTableEnd { + return lx.errorf("Expected end of table array name delimiter %q, "+ + "but got %q instead.", arrayTableEnd, r) + } + lx.emit(itemArrayTableEnd) + return lexTopEnd +} + +func lexTableNameStart(lx *lexer) stateFn { + switch r := lx.peek(); { + case r == tableEnd || r == eof: + return lx.errorf("Unexpected end of table name. (Table names cannot " + + "be empty.)") + case r == tableSep: + return lx.errorf("Unexpected table separator. (Table names cannot " + + "be empty.)") + case r == stringStart || r == rawStringStart: + lx.ignore() + lx.push(lexTableNameEnd) + return lexValue // reuse string lexing + case isWhitespace(r): + return lexTableNameStart + default: + return lexBareTableName + } +} + +// lexTableName lexes the name of a table. It assumes that at least one +// valid character for the table has already been read. +func lexBareTableName(lx *lexer) stateFn { + switch r := lx.next(); { + case isBareKeyChar(r): + return lexBareTableName + case r == tableSep || r == tableEnd: + lx.backup() + lx.emitTrim(itemText) + return lexTableNameEnd + default: + return lx.errorf("Bare keys cannot contain %q.", r) + } +} + +// lexTableNameEnd reads the end of a piece of a table name, optionally +// consuming whitespace. +func lexTableNameEnd(lx *lexer) stateFn { + switch r := lx.next(); { + case isWhitespace(r): + return lexTableNameEnd + case r == tableSep: + lx.ignore() + return lexTableNameStart + case r == tableEnd: + return lx.pop() + default: + return lx.errorf("Expected '.' or ']' to end table name, but got %q "+ + "instead.", r) + } +} + +// lexKeyStart consumes a key name up until the first non-whitespace character. +// lexKeyStart will ignore whitespace. +func lexKeyStart(lx *lexer) stateFn { + r := lx.peek() + switch { + case r == keySep: + return lx.errorf("Unexpected key separator %q.", keySep) + case isWhitespace(r) || isNL(r): + lx.next() + return lexSkip(lx, lexKeyStart) + case r == stringStart || r == rawStringStart: + lx.ignore() + lx.emit(itemKeyStart) + lx.push(lexKeyEnd) + return lexValue // reuse string lexing + default: + lx.ignore() + lx.emit(itemKeyStart) + return lexBareKey + } +} + +// lexBareKey consumes the text of a bare key. Assumes that the first character +// (which is not whitespace) has not yet been consumed. +func lexBareKey(lx *lexer) stateFn { + switch r := lx.next(); { + case isBareKeyChar(r): + return lexBareKey + case isWhitespace(r): + lx.emitTrim(itemText) + return lexKeyEnd + case r == keySep: + lx.backup() + lx.emitTrim(itemText) + return lexKeyEnd + default: + return lx.errorf("Bare keys cannot contain %q.", r) + } +} + +// lexKeyEnd consumes the end of a key and trims whitespace (up to the key +// separator). +func lexKeyEnd(lx *lexer) stateFn { + switch r := lx.next(); { + case r == keySep: + return lexSkip(lx, lexValue) + case isWhitespace(r): + return lexSkip(lx, lexKeyEnd) + default: + return lx.errorf("Expected key separator %q, but got %q instead.", + keySep, r) + } +} + +// lexValue starts the consumption of a value anywhere a value is expected. +// lexValue will ignore whitespace. +// After a value is lexed, the last state on the next is popped and returned. +func lexValue(lx *lexer) stateFn { + // We allow whitespace to precede a value, but NOT new lines. + // In array syntax, the array states are responsible for ignoring new + // lines. + r := lx.next() + if isWhitespace(r) { + return lexSkip(lx, lexValue) + } + + switch { + case r == arrayStart: + lx.ignore() + lx.emit(itemArray) + return lexArrayValue + case r == stringStart: + if lx.accept(stringStart) { + if lx.accept(stringStart) { + lx.ignore() // Ignore """ + return lexMultilineString + } + lx.backup() + } + lx.ignore() // ignore the '"' + return lexString + case r == rawStringStart: + if lx.accept(rawStringStart) { + if lx.accept(rawStringStart) { + lx.ignore() // Ignore """ + return lexMultilineRawString + } + lx.backup() + } + lx.ignore() // ignore the "'" + return lexRawString + case r == 't': + return lexTrue + case r == 'f': + return lexFalse + case r == '-': + return lexNumberStart + case isDigit(r): + lx.backup() // avoid an extra state and use the same as above + return lexNumberOrDateStart + case r == '.': // special error case, be kind to users + return lx.errorf("Floats must start with a digit, not '.'.") + } + return lx.errorf("Expected value but found %q instead.", r) +} + +// lexArrayValue consumes one value in an array. It assumes that '[' or ',' +// have already been consumed. All whitespace and new lines are ignored. +func lexArrayValue(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r) || isNL(r): + return lexSkip(lx, lexArrayValue) + case r == commentStart: + lx.push(lexArrayValue) + return lexCommentStart + case r == arrayValTerm: + return lx.errorf("Unexpected array value terminator %q.", + arrayValTerm) + case r == arrayEnd: + return lexArrayEnd + } + + lx.backup() + lx.push(lexArrayValueEnd) + return lexValue +} + +// lexArrayValueEnd consumes the cruft between values of an array. Namely, +// it ignores whitespace and expects either a ',' or a ']'. +func lexArrayValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r) || isNL(r): + return lexSkip(lx, lexArrayValueEnd) + case r == commentStart: + lx.push(lexArrayValueEnd) + return lexCommentStart + case r == arrayValTerm: + lx.ignore() + return lexArrayValue // move on to the next value + case r == arrayEnd: + return lexArrayEnd + } + return lx.errorf("Expected an array value terminator %q or an array "+ + "terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r) +} + +// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has +// just been consumed. +func lexArrayEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemArrayEnd) + return lx.pop() +} + +// lexString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. +func lexString(lx *lexer) stateFn { + r := lx.next() + switch { + case isNL(r): + return lx.errorf("Strings cannot contain new lines.") + case r == '\\': + lx.push(lexString) + return lexStringEscape + case r == stringEnd: + lx.backup() + lx.emit(itemString) + lx.next() + lx.ignore() + return lx.pop() + } + return lexString +} + +// lexMultilineString consumes the inner contents of a string. It assumes that +// the beginning '"""' has already been consumed and ignored. +func lexMultilineString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '\\': + return lexMultilineStringEscape + case r == stringEnd: + if lx.accept(stringEnd) { + if lx.accept(stringEnd) { + lx.backup() + lx.backup() + lx.backup() + lx.emit(itemMultilineString) + lx.next() + lx.next() + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + } + return lexMultilineString +} + +// lexRawString consumes a raw string. Nothing can be escaped in such a string. +// It assumes that the beginning "'" has already been consumed and ignored. +func lexRawString(lx *lexer) stateFn { + r := lx.next() + switch { + case isNL(r): + return lx.errorf("Strings cannot contain new lines.") + case r == rawStringEnd: + lx.backup() + lx.emit(itemRawString) + lx.next() + lx.ignore() + return lx.pop() + } + return lexRawString +} + +// lexMultilineRawString consumes a raw string. Nothing can be escaped in such +// a string. It assumes that the beginning "'" has already been consumed and +// ignored. +func lexMultilineRawString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == rawStringEnd: + if lx.accept(rawStringEnd) { + if lx.accept(rawStringEnd) { + lx.backup() + lx.backup() + lx.backup() + lx.emit(itemRawMultilineString) + lx.next() + lx.next() + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + } + return lexMultilineRawString +} + +// lexMultilineStringEscape consumes an escaped character. It assumes that the +// preceding '\\' has already been consumed. +func lexMultilineStringEscape(lx *lexer) stateFn { + // Handle the special case first: + if isNL(lx.next()) { + lx.next() + return lexMultilineString + } else { + lx.backup() + lx.push(lexMultilineString) + return lexStringEscape(lx) + } +} + +func lexStringEscape(lx *lexer) stateFn { + r := lx.next() + switch r { + case 'b': + fallthrough + case 't': + fallthrough + case 'n': + fallthrough + case 'f': + fallthrough + case 'r': + fallthrough + case '"': + fallthrough + case '\\': + return lx.pop() + case 'u': + return lexShortUnicodeEscape + case 'U': + return lexLongUnicodeEscape + } + return lx.errorf("Invalid escape character %q. Only the following "+ + "escape characters are allowed: "+ + "\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+ + "\\uXXXX and \\UXXXXXXXX.", r) +} + +func lexShortUnicodeEscape(lx *lexer) stateFn { + var r rune + for i := 0; i < 4; i++ { + r = lx.next() + if !isHexadecimal(r) { + return lx.errorf("Expected four hexadecimal digits after '\\u', "+ + "but got '%s' instead.", lx.current()) + } + } + return lx.pop() +} + +func lexLongUnicodeEscape(lx *lexer) stateFn { + var r rune + for i := 0; i < 8; i++ { + r = lx.next() + if !isHexadecimal(r) { + return lx.errorf("Expected eight hexadecimal digits after '\\U', "+ + "but got '%s' instead.", lx.current()) + } + } + return lx.pop() +} + +// lexNumberOrDateStart consumes either a (positive) integer, float or +// datetime. It assumes that NO negative sign has been consumed. +func lexNumberOrDateStart(lx *lexer) stateFn { + r := lx.next() + if !isDigit(r) { + if r == '.' { + return lx.errorf("Floats must start with a digit, not '.'.") + } else { + return lx.errorf("Expected a digit but got %q.", r) + } + } + return lexNumberOrDate +} + +// lexNumberOrDate consumes either a (positive) integer, float or datetime. +func lexNumberOrDate(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '-': + if lx.pos-lx.start != 5 { + return lx.errorf("All ISO8601 dates must be in full Zulu form.") + } + return lexDateAfterYear + case isDigit(r): + return lexNumberOrDate + case r == '.': + return lexFloatStart + } + + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format. +// It assumes that "YYYY-" has already been consumed. +func lexDateAfterYear(lx *lexer) stateFn { + formats := []rune{ + // digits are '0'. + // everything else is direct equality. + '0', '0', '-', '0', '0', + 'T', + '0', '0', ':', '0', '0', ':', '0', '0', + 'Z', + } + for _, f := range formats { + r := lx.next() + if f == '0' { + if !isDigit(r) { + return lx.errorf("Expected digit in ISO8601 datetime, "+ + "but found %q instead.", r) + } + } else if f != r { + return lx.errorf("Expected %q in ISO8601 datetime, "+ + "but found %q instead.", f, r) + } + } + lx.emit(itemDatetime) + return lx.pop() +} + +// lexNumberStart consumes either an integer or a float. It assumes that +// a negative sign has already been read, but that *no* digits have been +// consumed. lexNumberStart will move to the appropriate integer or float +// states. +func lexNumberStart(lx *lexer) stateFn { + // we MUST see a digit. Even floats have to start with a digit. + r := lx.next() + if !isDigit(r) { + if r == '.' { + return lx.errorf("Floats must start with a digit, not '.'.") + } else { + return lx.errorf("Expected a digit but got %q.", r) + } + } + return lexNumber +} + +// lexNumber consumes an integer or a float after seeing the first digit. +func lexNumber(lx *lexer) stateFn { + r := lx.next() + switch { + case isDigit(r): + return lexNumber + case r == '.': + return lexFloatStart + } + + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexFloatStart starts the consumption of digits of a float after a '.'. +// Namely, at least one digit is required. +func lexFloatStart(lx *lexer) stateFn { + r := lx.next() + if !isDigit(r) { + return lx.errorf("Floats must have a digit after the '.', but got "+ + "%q instead.", r) + } + return lexFloat +} + +// lexFloat consumes the digits of a float after a '.'. +// Assumes that one digit has been consumed after a '.' already. +func lexFloat(lx *lexer) stateFn { + r := lx.next() + if isDigit(r) { + return lexFloat + } + + lx.backup() + lx.emit(itemFloat) + return lx.pop() +} + +// lexConst consumes the s[1:] in s. It assumes that s[0] has already been +// consumed. +func lexConst(lx *lexer, s string) stateFn { + for i := range s[1:] { + if r := lx.next(); r != rune(s[i+1]) { + return lx.errorf("Expected %q, but found %q instead.", s[:i+1], + s[:i]+string(r)) + } + } + return nil +} + +// lexTrue consumes the "rue" in "true". It assumes that 't' has already +// been consumed. +func lexTrue(lx *lexer) stateFn { + if fn := lexConst(lx, "true"); fn != nil { + return fn + } + lx.emit(itemBool) + return lx.pop() +} + +// lexFalse consumes the "alse" in "false". It assumes that 'f' has already +// been consumed. +func lexFalse(lx *lexer) stateFn { + if fn := lexConst(lx, "false"); fn != nil { + return fn + } + lx.emit(itemBool) + return lx.pop() +} + +// lexCommentStart begins the lexing of a comment. It will emit +// itemCommentStart and consume no characters, passing control to lexComment. +func lexCommentStart(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemCommentStart) + return lexComment +} + +// lexComment lexes an entire comment. It assumes that '#' has been consumed. +// It will consume *up to* the first new line character, and pass control +// back to the last state on the stack. +func lexComment(lx *lexer) stateFn { + r := lx.peek() + if isNL(r) || r == eof { + lx.emit(itemText) + return lx.pop() + } + lx.next() + return lexComment +} + +// lexSkip ignores all slurped input and moves on to the next state. +func lexSkip(lx *lexer, nextState stateFn) stateFn { + return func(lx *lexer) stateFn { + lx.ignore() + return nextState + } +} + +// isWhitespace returns true if `r` is a whitespace character according +// to the spec. +func isWhitespace(r rune) bool { + return r == '\t' || r == ' ' +} + +func isNL(r rune) bool { + return r == '\n' || r == '\r' +} + +func isDigit(r rune) bool { + return r >= '0' && r <= '9' +} + +func isHexadecimal(r rune) bool { + return (r >= '0' && r <= '9') || + (r >= 'a' && r <= 'f') || + (r >= 'A' && r <= 'F') +} + +func isBareKeyChar(r rune) bool { + return (r >= 'A' && r <= 'Z') || + (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '_' || + r == '-' +} + +func (itype itemType) String() string { + switch itype { + case itemError: + return "Error" + case itemNIL: + return "NIL" + case itemEOF: + return "EOF" + case itemText: + return "Text" + case itemString: + return "String" + case itemRawString: + return "String" + case itemMultilineString: + return "String" + case itemRawMultilineString: + return "String" + case itemBool: + return "Bool" + case itemInteger: + return "Integer" + case itemFloat: + return "Float" + case itemDatetime: + return "DateTime" + case itemTableStart: + return "TableStart" + case itemTableEnd: + return "TableEnd" + case itemKeyStart: + return "KeyStart" + case itemArray: + return "Array" + case itemArrayEnd: + return "ArrayEnd" + case itemCommentStart: + return "CommentStart" + } + panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype))) +} + +func (item item) String() string { + return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/parse.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/parse.go new file mode 100644 index 0000000..c6069be --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/parse.go @@ -0,0 +1,498 @@ +package toml + +import ( + "fmt" + "log" + "strconv" + "strings" + "time" + "unicode" + "unicode/utf8" +) + +type parser struct { + mapping map[string]interface{} + types map[string]tomlType + lx *lexer + + // A list of keys in the order that they appear in the TOML data. + ordered []Key + + // the full key for the current hash in scope + context Key + + // the base key name for everything except hashes + currentKey string + + // rough approximation of line number + approxLine int + + // A map of 'key.group.names' to whether they were created implicitly. + implicits map[string]bool +} + +type parseError string + +func (pe parseError) Error() string { + return string(pe) +} + +func parse(data string) (p *parser, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + if err, ok = r.(parseError); ok { + return + } + panic(r) + } + }() + + p = &parser{ + mapping: make(map[string]interface{}), + types: make(map[string]tomlType), + lx: lex(data), + ordered: make([]Key, 0), + implicits: make(map[string]bool), + } + for { + item := p.next() + if item.typ == itemEOF { + break + } + p.topLevel(item) + } + + return p, nil +} + +func (p *parser) panicf(format string, v ...interface{}) { + msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s", + p.approxLine, p.current(), fmt.Sprintf(format, v...)) + panic(parseError(msg)) +} + +func (p *parser) next() item { + it := p.lx.nextItem() + if it.typ == itemError { + p.panicf("%s", it.val) + } + return it +} + +func (p *parser) bug(format string, v ...interface{}) { + log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...)) +} + +func (p *parser) expect(typ itemType) item { + it := p.next() + p.assertEqual(typ, it.typ) + return it +} + +func (p *parser) assertEqual(expected, got itemType) { + if expected != got { + p.bug("Expected '%s' but got '%s'.", expected, got) + } +} + +func (p *parser) topLevel(item item) { + switch item.typ { + case itemCommentStart: + p.approxLine = item.line + p.expect(itemText) + case itemTableStart: + kg := p.next() + p.approxLine = kg.line + + var key Key + for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() { + key = append(key, p.keyString(kg)) + } + p.assertEqual(itemTableEnd, kg.typ) + + p.establishContext(key, false) + p.setType("", tomlHash) + p.ordered = append(p.ordered, key) + case itemArrayTableStart: + kg := p.next() + p.approxLine = kg.line + + var key Key + for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() { + key = append(key, p.keyString(kg)) + } + p.assertEqual(itemArrayTableEnd, kg.typ) + + p.establishContext(key, true) + p.setType("", tomlArrayHash) + p.ordered = append(p.ordered, key) + case itemKeyStart: + kname := p.next() + p.approxLine = kname.line + p.currentKey = p.keyString(kname) + + val, typ := p.value(p.next()) + p.setValue(p.currentKey, val) + p.setType(p.currentKey, typ) + p.ordered = append(p.ordered, p.context.add(p.currentKey)) + p.currentKey = "" + default: + p.bug("Unexpected type at top level: %s", item.typ) + } +} + +// Gets a string for a key (or part of a key in a table name). +func (p *parser) keyString(it item) string { + switch it.typ { + case itemText: + return it.val + case itemString, itemMultilineString, + itemRawString, itemRawMultilineString: + s, _ := p.value(it) + return s.(string) + default: + p.bug("Unexpected key type: %s", it.typ) + panic("unreachable") + } +} + +// value translates an expected value from the lexer into a Go value wrapped +// as an empty interface. +func (p *parser) value(it item) (interface{}, tomlType) { + switch it.typ { + case itemString: + return p.replaceEscapes(it.val), p.typeOfPrimitive(it) + case itemMultilineString: + trimmed := stripFirstNewline(stripEscapedWhitespace(it.val)) + return p.replaceEscapes(trimmed), p.typeOfPrimitive(it) + case itemRawString: + return it.val, p.typeOfPrimitive(it) + case itemRawMultilineString: + return stripFirstNewline(it.val), p.typeOfPrimitive(it) + case itemBool: + switch it.val { + case "true": + return true, p.typeOfPrimitive(it) + case "false": + return false, p.typeOfPrimitive(it) + } + p.bug("Expected boolean value, but got '%s'.", it.val) + case itemInteger: + num, err := strconv.ParseInt(it.val, 10, 64) + if err != nil { + // See comment below for floats describing why we make a + // distinction between a bug and a user error. + if e, ok := err.(*strconv.NumError); ok && + e.Err == strconv.ErrRange { + + p.panicf("Integer '%s' is out of the range of 64-bit "+ + "signed integers.", it.val) + } else { + p.bug("Expected integer value, but got '%s'.", it.val) + } + } + return num, p.typeOfPrimitive(it) + case itemFloat: + num, err := strconv.ParseFloat(it.val, 64) + if err != nil { + // Distinguish float values. Normally, it'd be a bug if the lexer + // provides an invalid float, but it's possible that the float is + // out of range of valid values (which the lexer cannot determine). + // So mark the former as a bug but the latter as a legitimate user + // error. + // + // This is also true for integers. + if e, ok := err.(*strconv.NumError); ok && + e.Err == strconv.ErrRange { + + p.panicf("Float '%s' is out of the range of 64-bit "+ + "IEEE-754 floating-point numbers.", it.val) + } else { + p.bug("Expected float value, but got '%s'.", it.val) + } + } + return num, p.typeOfPrimitive(it) + case itemDatetime: + t, err := time.Parse("2006-01-02T15:04:05Z", it.val) + if err != nil { + p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val) + } + return t, p.typeOfPrimitive(it) + case itemArray: + array := make([]interface{}, 0) + types := make([]tomlType, 0) + + for it = p.next(); it.typ != itemArrayEnd; it = p.next() { + if it.typ == itemCommentStart { + p.expect(itemText) + continue + } + + val, typ := p.value(it) + array = append(array, val) + types = append(types, typ) + } + return array, p.typeOfArray(types) + } + p.bug("Unexpected value type: %s", it.typ) + panic("unreachable") +} + +// establishContext sets the current context of the parser, +// where the context is either a hash or an array of hashes. Which one is +// set depends on the value of the `array` parameter. +// +// Establishing the context also makes sure that the key isn't a duplicate, and +// will create implicit hashes automatically. +func (p *parser) establishContext(key Key, array bool) { + var ok bool + + // Always start at the top level and drill down for our context. + hashContext := p.mapping + keyContext := make(Key, 0) + + // We only need implicit hashes for key[0:-1] + for _, k := range key[0 : len(key)-1] { + _, ok = hashContext[k] + keyContext = append(keyContext, k) + + // No key? Make an implicit hash and move on. + if !ok { + p.addImplicit(keyContext) + hashContext[k] = make(map[string]interface{}) + } + + // If the hash context is actually an array of tables, then set + // the hash context to the last element in that array. + // + // Otherwise, it better be a table, since this MUST be a key group (by + // virtue of it not being the last element in a key). + switch t := hashContext[k].(type) { + case []map[string]interface{}: + hashContext = t[len(t)-1] + case map[string]interface{}: + hashContext = t + default: + p.panicf("Key '%s' was already created as a hash.", keyContext) + } + } + + p.context = keyContext + if array { + // If this is the first element for this array, then allocate a new + // list of tables for it. + k := key[len(key)-1] + if _, ok := hashContext[k]; !ok { + hashContext[k] = make([]map[string]interface{}, 0, 5) + } + + // Add a new table. But make sure the key hasn't already been used + // for something else. + if hash, ok := hashContext[k].([]map[string]interface{}); ok { + hashContext[k] = append(hash, make(map[string]interface{})) + } else { + p.panicf("Key '%s' was already created and cannot be used as "+ + "an array.", keyContext) + } + } else { + p.setValue(key[len(key)-1], make(map[string]interface{})) + } + p.context = append(p.context, key[len(key)-1]) +} + +// setValue sets the given key to the given value in the current context. +// It will make sure that the key hasn't already been defined, account for +// implicit key groups. +func (p *parser) setValue(key string, value interface{}) { + var tmpHash interface{} + var ok bool + + hash := p.mapping + keyContext := make(Key, 0) + for _, k := range p.context { + keyContext = append(keyContext, k) + if tmpHash, ok = hash[k]; !ok { + p.bug("Context for key '%s' has not been established.", keyContext) + } + switch t := tmpHash.(type) { + case []map[string]interface{}: + // The context is a table of hashes. Pick the most recent table + // defined as the current hash. + hash = t[len(t)-1] + case map[string]interface{}: + hash = t + default: + p.bug("Expected hash to have type 'map[string]interface{}', but "+ + "it has '%T' instead.", tmpHash) + } + } + keyContext = append(keyContext, key) + + if _, ok := hash[key]; ok { + // Typically, if the given key has already been set, then we have + // to raise an error since duplicate keys are disallowed. However, + // it's possible that a key was previously defined implicitly. In this + // case, it is allowed to be redefined concretely. (See the + // `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.) + // + // But we have to make sure to stop marking it as an implicit. (So that + // another redefinition provokes an error.) + // + // Note that since it has already been defined (as a hash), we don't + // want to overwrite it. So our business is done. + if p.isImplicit(keyContext) { + p.removeImplicit(keyContext) + return + } + + // Otherwise, we have a concrete key trying to override a previous + // key, which is *always* wrong. + p.panicf("Key '%s' has already been defined.", keyContext) + } + hash[key] = value +} + +// setType sets the type of a particular value at a given key. +// It should be called immediately AFTER setValue. +// +// Note that if `key` is empty, then the type given will be applied to the +// current context (which is either a table or an array of tables). +func (p *parser) setType(key string, typ tomlType) { + keyContext := make(Key, 0, len(p.context)+1) + for _, k := range p.context { + keyContext = append(keyContext, k) + } + if len(key) > 0 { // allow type setting for hashes + keyContext = append(keyContext, key) + } + p.types[keyContext.String()] = typ +} + +// addImplicit sets the given Key as having been created implicitly. +func (p *parser) addImplicit(key Key) { + p.implicits[key.String()] = true +} + +// removeImplicit stops tagging the given key as having been implicitly +// created. +func (p *parser) removeImplicit(key Key) { + p.implicits[key.String()] = false +} + +// isImplicit returns true if the key group pointed to by the key was created +// implicitly. +func (p *parser) isImplicit(key Key) bool { + return p.implicits[key.String()] +} + +// current returns the full key name of the current context. +func (p *parser) current() string { + if len(p.currentKey) == 0 { + return p.context.String() + } + if len(p.context) == 0 { + return p.currentKey + } + return fmt.Sprintf("%s.%s", p.context, p.currentKey) +} + +func stripFirstNewline(s string) string { + if len(s) == 0 || s[0] != '\n' { + return s + } + return s[1:len(s)] +} + +func stripEscapedWhitespace(s string) string { + esc := strings.Split(s, "\\\n") + if len(esc) > 1 { + for i := 1; i < len(esc); i++ { + esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace) + } + } + return strings.Join(esc, "") +} + +func (p *parser) replaceEscapes(str string) string { + var replaced []rune + s := []byte(str) + r := 0 + for r < len(s) { + if s[r] != '\\' { + c, size := utf8.DecodeRune(s[r:]) + r += size + replaced = append(replaced, c) + continue + } + r += 1 + if r >= len(s) { + p.bug("Escape sequence at end of string.") + return "" + } + switch s[r] { + default: + p.bug("Expected valid escape code after \\, but got %q.", s[r]) + return "" + case 'b': + replaced = append(replaced, rune(0x0008)) + r += 1 + case 't': + replaced = append(replaced, rune(0x0009)) + r += 1 + case 'n': + replaced = append(replaced, rune(0x000A)) + r += 1 + case 'f': + replaced = append(replaced, rune(0x000C)) + r += 1 + case 'r': + replaced = append(replaced, rune(0x000D)) + r += 1 + case '"': + replaced = append(replaced, rune(0x0022)) + r += 1 + case '\\': + replaced = append(replaced, rune(0x005C)) + r += 1 + case 'u': + // At this point, we know we have a Unicode escape of the form + // `uXXXX` at [r, r+5). (Because the lexer guarantees this + // for us.) + escaped := p.asciiEscapeToUnicode(s[r+1 : r+5]) + replaced = append(replaced, escaped) + r += 5 + case 'U': + // At this point, we know we have a Unicode escape of the form + // `uXXXX` at [r, r+9). (Because the lexer guarantees this + // for us.) + escaped := p.asciiEscapeToUnicode(s[r+1 : r+9]) + replaced = append(replaced, escaped) + r += 9 + } + } + return string(replaced) +} + +func (p *parser) asciiEscapeToUnicode(bs []byte) rune { + s := string(bs) + hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32) + if err != nil { + p.bug("Could not parse '%s' as a hexadecimal number, but the "+ + "lexer claims it's OK: %s", s, err) + } + + // BUG(burntsushi) + // I honestly don't understand how this works. I can't seem + // to find a way to make this fail. I figured this would fail on invalid + // UTF-8 characters like U+DCFF, but it doesn't. + if !utf8.ValidString(string(rune(hex))) { + p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s) + } + return rune(hex) +} + +func isStringType(ty itemType) bool { + return ty == itemString || ty == itemMultilineString || + ty == itemRawString || ty == itemRawMultilineString +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_check.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_check.go new file mode 100644 index 0000000..c73f8af --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_check.go @@ -0,0 +1,91 @@ +package toml + +// tomlType represents any Go type that corresponds to a TOML type. +// While the first draft of the TOML spec has a simplistic type system that +// probably doesn't need this level of sophistication, we seem to be militating +// toward adding real composite types. +type tomlType interface { + typeString() string +} + +// typeEqual accepts any two types and returns true if they are equal. +func typeEqual(t1, t2 tomlType) bool { + if t1 == nil || t2 == nil { + return false + } + return t1.typeString() == t2.typeString() +} + +func typeIsHash(t tomlType) bool { + return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash) +} + +type tomlBaseType string + +func (btype tomlBaseType) typeString() string { + return string(btype) +} + +func (btype tomlBaseType) String() string { + return btype.typeString() +} + +var ( + tomlInteger tomlBaseType = "Integer" + tomlFloat tomlBaseType = "Float" + tomlDatetime tomlBaseType = "Datetime" + tomlString tomlBaseType = "String" + tomlBool tomlBaseType = "Bool" + tomlArray tomlBaseType = "Array" + tomlHash tomlBaseType = "Hash" + tomlArrayHash tomlBaseType = "ArrayHash" +) + +// typeOfPrimitive returns a tomlType of any primitive value in TOML. +// Primitive values are: Integer, Float, Datetime, String and Bool. +// +// Passing a lexer item other than the following will cause a BUG message +// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime. +func (p *parser) typeOfPrimitive(lexItem item) tomlType { + switch lexItem.typ { + case itemInteger: + return tomlInteger + case itemFloat: + return tomlFloat + case itemDatetime: + return tomlDatetime + case itemString: + return tomlString + case itemMultilineString: + return tomlString + case itemRawString: + return tomlString + case itemRawMultilineString: + return tomlString + case itemBool: + return tomlBool + } + p.bug("Cannot infer primitive type of lex item '%s'.", lexItem) + panic("unreachable") +} + +// typeOfArray returns a tomlType for an array given a list of types of its +// values. +// +// In the current spec, if an array is homogeneous, then its type is always +// "Array". If the array is not homogeneous, an error is generated. +func (p *parser) typeOfArray(types []tomlType) tomlType { + // Empty arrays are cool. + if len(types) == 0 { + return tomlArray + } + + theType := types[0] + for _, t := range types[1:] { + if !typeEqual(theType, t) { + p.panicf("Array contains values of type '%s' and '%s', but "+ + "arrays must be homogeneous.", theType, t) + } + } + return tomlArray +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_fields.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_fields.go new file mode 100644 index 0000000..7592f87 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/BurntSushi/toml/type_fields.go @@ -0,0 +1,241 @@ +package toml + +// Struct field handling is adapted from code in encoding/json: +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the Go distribution. + +import ( + "reflect" + "sort" + "sync" +) + +// A field represents a single field found in a struct. +type field struct { + name string // the name of the field (`toml` tag included) + tag bool // whether field has a `toml` tag + index []int // represents the depth of an anonymous field + typ reflect.Type // the type of the field +} + +// byName sorts field by name, breaking ties with depth, +// then breaking ties with "name came from toml tag", then +// breaking ties with index sequence. +type byName []field + +func (x byName) Len() int { return len(x) } + +func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byName) Less(i, j int) bool { + if x[i].name != x[j].name { + return x[i].name < x[j].name + } + if len(x[i].index) != len(x[j].index) { + return len(x[i].index) < len(x[j].index) + } + if x[i].tag != x[j].tag { + return x[i].tag + } + return byIndex(x).Less(i, j) +} + +// byIndex sorts field by index sequence. +type byIndex []field + +func (x byIndex) Len() int { return len(x) } + +func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byIndex) Less(i, j int) bool { + for k, xik := range x[i].index { + if k >= len(x[j].index) { + return false + } + if xik != x[j].index[k] { + return xik < x[j].index[k] + } + } + return len(x[i].index) < len(x[j].index) +} + +// typeFields returns a list of fields that TOML should recognize for the given +// type. The algorithm is breadth-first search over the set of structs to +// include - the top struct and then any reachable anonymous structs. +func typeFields(t reflect.Type) []field { + // Anonymous fields to explore at the current level and the next. + current := []field{} + next := []field{{typ: t}} + + // Count of queued names for current level and the next. + count := map[reflect.Type]int{} + nextCount := map[reflect.Type]int{} + + // Types already visited at an earlier level. + visited := map[reflect.Type]bool{} + + // Fields found. + var fields []field + + for len(next) > 0 { + current, next = next, current[:0] + count, nextCount = nextCount, map[reflect.Type]int{} + + for _, f := range current { + if visited[f.typ] { + continue + } + visited[f.typ] = true + + // Scan f.typ for fields to include. + for i := 0; i < f.typ.NumField(); i++ { + sf := f.typ.Field(i) + if sf.PkgPath != "" { // unexported + continue + } + name := sf.Tag.Get("toml") + if name == "-" { + continue + } + index := make([]int, len(f.index)+1) + copy(index, f.index) + index[len(f.index)] = i + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + // Follow pointer. + ft = ft.Elem() + } + + // Record found field and index sequence. + if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { + tagged := name != "" + if name == "" { + name = sf.Name + } + fields = append(fields, field{name, tagged, index, ft}) + if count[f.typ] > 1 { + // If there were multiple instances, add a second, + // so that the annihilation code will see a duplicate. + // It only cares about the distinction between 1 or 2, + // so don't bother generating any more copies. + fields = append(fields, fields[len(fields)-1]) + } + continue + } + + // Record new anonymous struct to explore in next round. + nextCount[ft]++ + if nextCount[ft] == 1 { + f := field{name: ft.Name(), index: index, typ: ft} + next = append(next, f) + } + } + } + } + + sort.Sort(byName(fields)) + + // Delete all fields that are hidden by the Go rules for embedded fields, + // except that fields with TOML tags are promoted. + + // The fields are sorted in primary order of name, secondary order + // of field index length. Loop over names; for each name, delete + // hidden fields by choosing the one dominant field that survives. + out := fields[:0] + for advance, i := 0, 0; i < len(fields); i += advance { + // One iteration per name. + // Find the sequence of fields with the name of this first field. + fi := fields[i] + name := fi.name + for advance = 1; i+advance < len(fields); advance++ { + fj := fields[i+advance] + if fj.name != name { + break + } + } + if advance == 1 { // Only one field with this name + out = append(out, fi) + continue + } + dominant, ok := dominantField(fields[i : i+advance]) + if ok { + out = append(out, dominant) + } + } + + fields = out + sort.Sort(byIndex(fields)) + + return fields +} + +// dominantField looks through the fields, all of which are known to +// have the same name, to find the single field that dominates the +// others using Go's embedding rules, modified by the presence of +// TOML tags. If there are multiple top-level fields, the boolean +// will be false: This condition is an error in Go and we skip all +// the fields. +func dominantField(fields []field) (field, bool) { + // The fields are sorted in increasing index-length order. The winner + // must therefore be one with the shortest index length. Drop all + // longer entries, which is easy: just truncate the slice. + length := len(fields[0].index) + tagged := -1 // Index of first tagged field. + for i, f := range fields { + if len(f.index) > length { + fields = fields[:i] + break + } + if f.tag { + if tagged >= 0 { + // Multiple tagged fields at the same level: conflict. + // Return no field. + return field{}, false + } + tagged = i + } + } + if tagged >= 0 { + return fields[tagged], true + } + // All remaining fields have the same length. If there's more than one, + // we have a conflict (two fields named "X" at the same level) and we + // return no field. + if len(fields) > 1 { + return field{}, false + } + return fields[0], true +} + +var fieldCache struct { + sync.RWMutex + m map[reflect.Type][]field +} + +// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. +func cachedTypeFields(t reflect.Type) []field { + fieldCache.RLock() + f := fieldCache.m[t] + fieldCache.RUnlock() + if f != nil { + return f + } + + // Compute fields without lock. + // Might duplicate effort but won't hold other computations back. + f = typeFields(t) + if f == nil { + f = []field{} + } + + fieldCache.Lock() + if fieldCache.m == nil { + fieldCache.m = map[reflect.Type][]field{} + } + fieldCache.m[t] = f + fieldCache.Unlock() + return f +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/LICENSE new file mode 100644 index 0000000..14e2f77 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/appengine.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/appengine.go new file mode 100644 index 0000000..565614e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/appengine.go @@ -0,0 +1,19 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build appengine + +package mysql + +import ( + "appengine/cloudsql" +) + +func init() { + RegisterDial("cloudsql", cloudsql.Dial) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/buffer.go new file mode 100644 index 0000000..2001fea --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -0,0 +1,147 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "io" + "net" + "time" +) + +const defaultBufSize = 4096 + +// A buffer which is used for both reading and writing. +// This is possible since communication on each connection is synchronous. +// In other words, we can't write and read simultaneously on the same connection. +// The buffer is similar to bufio.Reader / Writer but zero-copy-ish +// Also highly optimized for this particular use case. +type buffer struct { + buf []byte + nc net.Conn + idx int + length int + timeout time.Duration +} + +func newBuffer(nc net.Conn) buffer { + var b [defaultBufSize]byte + return buffer{ + buf: b[:], + nc: nc, + } +} + +// fill reads into the buffer until at least _need_ bytes are in it +func (b *buffer) fill(need int) error { + n := b.length + + // move existing data to the beginning + if n > 0 && b.idx > 0 { + copy(b.buf[0:n], b.buf[b.idx:]) + } + + // grow buffer if necessary + // TODO: let the buffer shrink again at some point + // Maybe keep the org buf slice and swap back? + if need > len(b.buf) { + // Round up to the next multiple of the default size + newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + copy(newBuf, b.buf) + b.buf = newBuf + } + + b.idx = 0 + + for { + if b.timeout > 0 { + if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + return err + } + } + + nn, err := b.nc.Read(b.buf[n:]) + n += nn + + switch err { + case nil: + if n < need { + continue + } + b.length = n + return nil + + case io.EOF: + if n >= need { + b.length = n + return nil + } + return io.ErrUnexpectedEOF + + default: + return err + } + } +} + +// returns next N bytes from buffer. +// The returned slice is only guaranteed to be valid until the next read +func (b *buffer) readNext(need int) ([]byte, error) { + if b.length < need { + // refill + if err := b.fill(need); err != nil { + return nil, err + } + } + + offset := b.idx + b.idx += need + b.length -= need + return b.buf[offset:b.idx], nil +} + +// returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeBuffer(length int) []byte { + if b.length > 0 { + return nil + } + + // test (cheap) general case first + if length <= defaultBufSize || length <= cap(b.buf) { + return b.buf[:length] + } + + if length < maxPacketSize { + b.buf = make([]byte, length) + return b.buf + } + return make([]byte, length) +} + +// shortcut which can be used if the requested buffer is guaranteed to be +// smaller than defaultBufSize +// Only one buffer (total) can be used at a time. +func (b *buffer) takeSmallBuffer(length int) []byte { + if b.length == 0 { + return b.buf[:length] + } + return nil +} + +// takeCompleteBuffer returns the complete existing buffer. +// This can be used if the necessary buffer size is unknown. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeCompleteBuffer() []byte { + if b.length == 0 { + return b.buf + } + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/collations.go new file mode 100644 index 0000000..82079cf --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/collations.go @@ -0,0 +1,250 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +const defaultCollation = "utf8_general_ci" + +// A list of available collations mapped to the internal ID. +// To update this map use the following MySQL query: +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS +var collations = map[string]byte{ + "big5_chinese_ci": 1, + "latin2_czech_cs": 2, + "dec8_swedish_ci": 3, + "cp850_general_ci": 4, + "latin1_german1_ci": 5, + "hp8_english_ci": 6, + "koi8r_general_ci": 7, + "latin1_swedish_ci": 8, + "latin2_general_ci": 9, + "swe7_swedish_ci": 10, + "ascii_general_ci": 11, + "ujis_japanese_ci": 12, + "sjis_japanese_ci": 13, + "cp1251_bulgarian_ci": 14, + "latin1_danish_ci": 15, + "hebrew_general_ci": 16, + "tis620_thai_ci": 18, + "euckr_korean_ci": 19, + "latin7_estonian_cs": 20, + "latin2_hungarian_ci": 21, + "koi8u_general_ci": 22, + "cp1251_ukrainian_ci": 23, + "gb2312_chinese_ci": 24, + "greek_general_ci": 25, + "cp1250_general_ci": 26, + "latin2_croatian_ci": 27, + "gbk_chinese_ci": 28, + "cp1257_lithuanian_ci": 29, + "latin5_turkish_ci": 30, + "latin1_german2_ci": 31, + "armscii8_general_ci": 32, + "utf8_general_ci": 33, + "cp1250_czech_cs": 34, + "ucs2_general_ci": 35, + "cp866_general_ci": 36, + "keybcs2_general_ci": 37, + "macce_general_ci": 38, + "macroman_general_ci": 39, + "cp852_general_ci": 40, + "latin7_general_ci": 41, + "latin7_general_cs": 42, + "macce_bin": 43, + "cp1250_croatian_ci": 44, + "utf8mb4_general_ci": 45, + "utf8mb4_bin": 46, + "latin1_bin": 47, + "latin1_general_ci": 48, + "latin1_general_cs": 49, + "cp1251_bin": 50, + "cp1251_general_ci": 51, + "cp1251_general_cs": 52, + "macroman_bin": 53, + "utf16_general_ci": 54, + "utf16_bin": 55, + "utf16le_general_ci": 56, + "cp1256_general_ci": 57, + "cp1257_bin": 58, + "cp1257_general_ci": 59, + "utf32_general_ci": 60, + "utf32_bin": 61, + "utf16le_bin": 62, + "binary": 63, + "armscii8_bin": 64, + "ascii_bin": 65, + "cp1250_bin": 66, + "cp1256_bin": 67, + "cp866_bin": 68, + "dec8_bin": 69, + "greek_bin": 70, + "hebrew_bin": 71, + "hp8_bin": 72, + "keybcs2_bin": 73, + "koi8r_bin": 74, + "koi8u_bin": 75, + "latin2_bin": 77, + "latin5_bin": 78, + "latin7_bin": 79, + "cp850_bin": 80, + "cp852_bin": 81, + "swe7_bin": 82, + "utf8_bin": 83, + "big5_bin": 84, + "euckr_bin": 85, + "gb2312_bin": 86, + "gbk_bin": 87, + "sjis_bin": 88, + "tis620_bin": 89, + "ucs2_bin": 90, + "ujis_bin": 91, + "geostd8_general_ci": 92, + "geostd8_bin": 93, + "latin1_spanish_ci": 94, + "cp932_japanese_ci": 95, + "cp932_bin": 96, + "eucjpms_japanese_ci": 97, + "eucjpms_bin": 98, + "cp1250_polish_ci": 99, + "utf16_unicode_ci": 101, + "utf16_icelandic_ci": 102, + "utf16_latvian_ci": 103, + "utf16_romanian_ci": 104, + "utf16_slovenian_ci": 105, + "utf16_polish_ci": 106, + "utf16_estonian_ci": 107, + "utf16_spanish_ci": 108, + "utf16_swedish_ci": 109, + "utf16_turkish_ci": 110, + "utf16_czech_ci": 111, + "utf16_danish_ci": 112, + "utf16_lithuanian_ci": 113, + "utf16_slovak_ci": 114, + "utf16_spanish2_ci": 115, + "utf16_roman_ci": 116, + "utf16_persian_ci": 117, + "utf16_esperanto_ci": 118, + "utf16_hungarian_ci": 119, + "utf16_sinhala_ci": 120, + "utf16_german2_ci": 121, + "utf16_croatian_ci": 122, + "utf16_unicode_520_ci": 123, + "utf16_vietnamese_ci": 124, + "ucs2_unicode_ci": 128, + "ucs2_icelandic_ci": 129, + "ucs2_latvian_ci": 130, + "ucs2_romanian_ci": 131, + "ucs2_slovenian_ci": 132, + "ucs2_polish_ci": 133, + "ucs2_estonian_ci": 134, + "ucs2_spanish_ci": 135, + "ucs2_swedish_ci": 136, + "ucs2_turkish_ci": 137, + "ucs2_czech_ci": 138, + "ucs2_danish_ci": 139, + "ucs2_lithuanian_ci": 140, + "ucs2_slovak_ci": 141, + "ucs2_spanish2_ci": 142, + "ucs2_roman_ci": 143, + "ucs2_persian_ci": 144, + "ucs2_esperanto_ci": 145, + "ucs2_hungarian_ci": 146, + "ucs2_sinhala_ci": 147, + "ucs2_german2_ci": 148, + "ucs2_croatian_ci": 149, + "ucs2_unicode_520_ci": 150, + "ucs2_vietnamese_ci": 151, + "ucs2_general_mysql500_ci": 159, + "utf32_unicode_ci": 160, + "utf32_icelandic_ci": 161, + "utf32_latvian_ci": 162, + "utf32_romanian_ci": 163, + "utf32_slovenian_ci": 164, + "utf32_polish_ci": 165, + "utf32_estonian_ci": 166, + "utf32_spanish_ci": 167, + "utf32_swedish_ci": 168, + "utf32_turkish_ci": 169, + "utf32_czech_ci": 170, + "utf32_danish_ci": 171, + "utf32_lithuanian_ci": 172, + "utf32_slovak_ci": 173, + "utf32_spanish2_ci": 174, + "utf32_roman_ci": 175, + "utf32_persian_ci": 176, + "utf32_esperanto_ci": 177, + "utf32_hungarian_ci": 178, + "utf32_sinhala_ci": 179, + "utf32_german2_ci": 180, + "utf32_croatian_ci": 181, + "utf32_unicode_520_ci": 182, + "utf32_vietnamese_ci": 183, + "utf8_unicode_ci": 192, + "utf8_icelandic_ci": 193, + "utf8_latvian_ci": 194, + "utf8_romanian_ci": 195, + "utf8_slovenian_ci": 196, + "utf8_polish_ci": 197, + "utf8_estonian_ci": 198, + "utf8_spanish_ci": 199, + "utf8_swedish_ci": 200, + "utf8_turkish_ci": 201, + "utf8_czech_ci": 202, + "utf8_danish_ci": 203, + "utf8_lithuanian_ci": 204, + "utf8_slovak_ci": 205, + "utf8_spanish2_ci": 206, + "utf8_roman_ci": 207, + "utf8_persian_ci": 208, + "utf8_esperanto_ci": 209, + "utf8_hungarian_ci": 210, + "utf8_sinhala_ci": 211, + "utf8_german2_ci": 212, + "utf8_croatian_ci": 213, + "utf8_unicode_520_ci": 214, + "utf8_vietnamese_ci": 215, + "utf8_general_mysql500_ci": 223, + "utf8mb4_unicode_ci": 224, + "utf8mb4_icelandic_ci": 225, + "utf8mb4_latvian_ci": 226, + "utf8mb4_romanian_ci": 227, + "utf8mb4_slovenian_ci": 228, + "utf8mb4_polish_ci": 229, + "utf8mb4_estonian_ci": 230, + "utf8mb4_spanish_ci": 231, + "utf8mb4_swedish_ci": 232, + "utf8mb4_turkish_ci": 233, + "utf8mb4_czech_ci": 234, + "utf8mb4_danish_ci": 235, + "utf8mb4_lithuanian_ci": 236, + "utf8mb4_slovak_ci": 237, + "utf8mb4_spanish2_ci": 238, + "utf8mb4_roman_ci": 239, + "utf8mb4_persian_ci": 240, + "utf8mb4_esperanto_ci": 241, + "utf8mb4_hungarian_ci": 242, + "utf8mb4_sinhala_ci": 243, + "utf8mb4_german2_ci": 244, + "utf8mb4_croatian_ci": 245, + "utf8mb4_unicode_520_ci": 246, + "utf8mb4_vietnamese_ci": 247, +} + +// A blacklist of collations which is unsafe to interpolate parameters. +// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. +var unsafeCollations = map[string]bool{ + "big5_chinese_ci": true, + "sjis_japanese_ci": true, + "gbk_chinese_ci": true, + "big5_bin": true, + "gb2312_bin": true, + "gbk_bin": true, + "sjis_bin": true, + "cp932_japanese_ci": true, + "cp932_bin": true, +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/connection.go new file mode 100644 index 0000000..c3899de --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/connection.go @@ -0,0 +1,372 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "net" + "strconv" + "strings" + "time" +) + +type mysqlConn struct { + buf buffer + netConn net.Conn + affectedRows uint64 + insertId uint64 + cfg *Config + maxPacketAllowed int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + parseTime bool + strict bool +} + +// Handles parameters set in DSN after the connection is established +func (mc *mysqlConn) handleParams() (err error) { + for param, val := range mc.cfg.Params { + switch param { + // Charset + case "charset": + charsets := strings.Split(val, ",") + for i := range charsets { + // ignore errors here - a charset may not exist + err = mc.exec("SET NAMES " + charsets[i]) + if err == nil { + break + } + } + if err != nil { + return + } + + // System Vars + default: + err = mc.exec("SET " + param + "=" + val + "") + if err != nil { + return + } + } + } + + return +} + +func (mc *mysqlConn) Begin() (driver.Tx, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + err := mc.exec("START TRANSACTION") + if err == nil { + return &mysqlTx{mc}, err + } + + return nil, err +} + +func (mc *mysqlConn) Close() (err error) { + // Makes Close idempotent + if mc.netConn != nil { + err = mc.writeCommandPacket(comQuit) + } + + mc.cleanup() + + return +} + +// Closes the network connection and unsets internal variables. Do not call this +// function after successfully authentication, call Close instead. This function +// is called before auth or on auth failure because MySQL will have already +// closed the network connection. +func (mc *mysqlConn) cleanup() { + // Makes cleanup idempotent + if mc.netConn != nil { + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } + mc.netConn = nil + } + mc.cfg = nil + mc.buf.nc = nil +} + +func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := mc.writeCommandPacketStr(comStmtPrepare, query) + if err != nil { + return nil, err + } + + stmt := &mysqlStmt{ + mc: mc, + } + + // Read Result + columnCount, err := stmt.readPrepareResultPacket() + if err == nil { + if stmt.paramCount > 0 { + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + } + + if columnCount > 0 { + err = mc.readUntilEOF() + } + } + + return stmt, err +} + +func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { + buf := mc.buf.takeCompleteBuffer() + if buf == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return "", driver.ErrBadConn + } + buf = buf[:0] + argPos := 0 + + for i := 0; i < len(query); i++ { + q := strings.IndexByte(query[i:], '?') + if q == -1 { + buf = append(buf, query[i:]...) + break + } + buf = append(buf, query[i:i+q]...) + i += q + + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + continue + } + + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + v := v.In(mc.cfg.Loc) + v = v.Add(time.Nanosecond * 500) // To round under microsecond + year := v.Year() + year100 := year / 100 + year1 := year % 100 + month := v.Month() + day := v.Day() + hour := v.Hour() + minute := v.Minute() + second := v.Second() + micro := v.Nanosecond() / 1000 + + buf = append(buf, []byte{ + '\'', + digits10[year100], digits01[year100], + digits10[year1], digits01[year1], + '-', + digits10[month], digits01[month], + '-', + digits10[day], digits01[day], + ' ', + digits10[hour], digits01[hour], + ':', + digits10[minute], digits01[minute], + ':', + digits10[second], digits01[second], + }...) + + if micro != 0 { + micro10000 := micro / 10000 + micro100 := micro / 100 % 100 + micro1 := micro % 100 + buf = append(buf, []byte{ + '.', + digits10[micro10000], digits01[micro10000], + digits10[micro100], digits01[micro100], + digits10[micro1], digits01[micro1], + }...) + } + buf = append(buf, '\'') + } + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeStringBackslash(buf, v) + } else { + buf = escapeStringQuotes(buf, v) + } + buf = append(buf, '\'') + default: + return "", driver.ErrSkip + } + + if len(buf)+4 > mc.maxPacketAllowed { + return "", driver.ErrSkip + } + } + if argPos != len(args) { + return "", driver.ErrSkip + } + return string(buf), nil +} + +func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil + } + mc.affectedRows = 0 + mc.insertId = 0 + + err := mc.exec(query) + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, err + } + return nil, err +} + +// Internal function to execute commands +func (mc *mysqlConn) exec(query string) error { + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err != nil { + return err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil && resLen > 0 { + if err = mc.readUntilEOF(); err != nil { + return err + } + + err = mc.readUntilEOF() + } + + return err +} + +func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce roundtrip + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil + } + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err == nil { + // Read Result + var resLen int + resLen, err = mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + + if resLen == 0 { + // no columns, no more data + return emptyRows{}, nil + } + // Columns + rows.columns, err = mc.readColumns(resLen) + return rows, err + } + } + return nil, err +} + +// Gets the value of the given MySQL System Variable +// The returned byte slice is only valid until the next read +func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { + // Send command + if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + return nil, err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + + if resLen > 0 { + // Columns + if err := mc.readUntilEOF(); err != nil { + return nil, err + } + } + + dest := make([]driver.Value, resLen) + if err = rows.readRow(dest); err == nil { + return dest[0].([]byte), mc.readUntilEOF() + } + } + return nil, err +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/const.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/const.go new file mode 100644 index 0000000..88cfff3 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/const.go @@ -0,0 +1,163 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +const ( + minProtocolVersion byte = 10 + maxPacketSize = 1<<24 - 1 + timeFormat = "2006-01-02 15:04:05.999999" +) + +// MySQL constants documentation: +// http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +const ( + iOK byte = 0x00 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff +) + +// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags +type clientFlag uint32 + +const ( + clientLongPassword clientFlag = 1 << iota + clientFoundRows + clientLongFlag + clientConnectWithDB + clientNoSchema + clientCompress + clientODBC + clientLocalFiles + clientIgnoreSpace + clientProtocol41 + clientInteractive + clientSSL + clientIgnoreSIGPIPE + clientTransactions + clientReserved + clientSecureConn + clientMultiStatements + clientMultiResults + clientPSMultiResults + clientPluginAuth + clientConnectAttrs + clientPluginAuthLenEncClientData + clientCanHandleExpiredPasswords + clientSessionTrack + clientDeprecateEOF +) + +const ( + comQuit byte = iota + 1 + comInitDB + comQuery + comFieldList + comCreateDB + comDropDB + comRefresh + comShutdown + comStatistics + comProcessInfo + comConnect + comProcessKill + comDebug + comPing + comTime + comDelayedInsert + comChangeUser + comBinlogDump + comTableDump + comConnectOut + comRegisterSlave + comStmtPrepare + comStmtExecute + comStmtSendLongData + comStmtClose + comStmtReset + comSetOption + comStmtFetch +) + +// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +const ( + fieldTypeDecimal byte = iota + fieldTypeTiny + fieldTypeShort + fieldTypeLong + fieldTypeFloat + fieldTypeDouble + fieldTypeNULL + fieldTypeTimestamp + fieldTypeLongLong + fieldTypeInt24 + fieldTypeDate + fieldTypeTime + fieldTypeDateTime + fieldTypeYear + fieldTypeNewDate + fieldTypeVarChar + fieldTypeBit +) +const ( + fieldTypeJSON byte = iota + 0xf5 + fieldTypeNewDecimal + fieldTypeEnum + fieldTypeSet + fieldTypeTinyBLOB + fieldTypeMediumBLOB + fieldTypeLongBLOB + fieldTypeBLOB + fieldTypeVarString + fieldTypeString + fieldTypeGeometry +) + +type fieldFlag uint16 + +const ( + flagNotNULL fieldFlag = 1 << iota + flagPriKey + flagUniqueKey + flagMultipleKey + flagBLOB + flagUnsigned + flagZeroFill + flagBinary + flagEnum + flagAutoIncrement + flagTimestamp + flagSet + flagUnknown1 + flagUnknown2 + flagUnknown3 + flagUnknown4 +) + +// http://dev.mysql.com/doc/internals/en/status-flags.html +type statusFlag uint16 + +const ( + statusInTrans statusFlag = 1 << iota + statusInAutocommit + statusReserved // Not in documentation + statusMoreResultsExists + statusNoGoodIndexUsed + statusNoIndexUsed + statusCursorExists + statusLastRowSent + statusDbDropped + statusNoBackslashEscapes + statusMetadataChanged + statusQueryWasSlow + statusPsOutParams + statusInTransReadonly + statusSessionStateChanged +) diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/driver.go new file mode 100644 index 0000000..899f955 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/driver.go @@ -0,0 +1,167 @@ +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package mysql provides a MySQL driver for Go's database/sql package +// +// The driver should be used via the database/sql package: +// +// import "database/sql" +// import _ "github.com/go-sql-driver/mysql" +// +// db, err := sql.Open("mysql", "user:password@/dbname") +// +// See https://github.com/go-sql-driver/mysql#usage for details +package mysql + +import ( + "database/sql" + "database/sql/driver" + "net" +) + +// MySQLDriver is exported to make the driver directly accessible. +// In general the driver is used via the database/sql package. +type MySQLDriver struct{} + +// DialFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDial +type DialFunc func(addr string) (net.Conn, error) + +var dials map[string]DialFunc + +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +func RegisterDial(net string, dial DialFunc) { + if dials == nil { + dials = make(map[string]DialFunc) + } + dials[net] = dial +} + +// Open new Connection. +// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how +// the DSN string is formated +func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxPacketAllowed: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + } + mc.cfg, err = ParseDSN(dsn) + if err != nil { + return nil, err + } + mc.parseTime = mc.cfg.ParseTime + mc.strict = mc.cfg.Strict + + // Connect to Server + if dial, ok := dials[mc.cfg.Net]; ok { + mc.netConn, err = dial(mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + } + if err != nil { + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + cipher, err := mc.readInitPacket() + if err != nil { + mc.cleanup() + return nil, err + } + + // Send Client Authentication Packet + if err = mc.writeAuthPacket(cipher); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = handleAuthResult(mc, cipher); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxPacketAllowed = stringToInt(maxap) - 1 + if mc.maxPacketAllowed < maxPacketSize { + mc.maxWriteSize = mc.maxPacketAllowed + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +func handleAuthResult(mc *mysqlConn, cipher []byte) error { + // Read Result Packet + err := mc.readResultOK() + if err == nil { + return nil // auth successful + } + + if mc.cfg == nil { + return err // auth failed and retry not possible + } + + // Retry auth if configured to do so. + if mc.cfg.AllowOldPasswords && err == ErrOldPassword { + // Retry with old authentication method. Note: there are edge cases + // where this should work but doesn't; this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + if err = mc.writeOldAuthPacket(cipher); err != nil { + return err + } + err = mc.readResultOK() + } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { + // Retry with clear text password for + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + if err = mc.writeClearAuthPacket(); err != nil { + return err + } + err = mc.readResultOK() + } + return err +} + +func init() { + sql.Register("mysql", &MySQLDriver{}) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/dsn.go new file mode 100644 index 0000000..73138bc --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -0,0 +1,513 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "strings" + "time" +) + +var ( + errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") +) + +// Config is a configuration parsed from a DSN string +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE + AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowOldPasswords bool // Allows the old insecure password method + ClientFoundRows bool // Return number of matching rows instead of rows changed + ColumnsWithAlias bool // Prepend table alias to column names + InterpolateParams bool // Interpolate placeholders into query string + MultiStatements bool // Allow multiple statements in one query + ParseTime bool // Parse time values to time.Time + Strict bool // Return warnings as errors +} + +// FormatDSN formats the given Config into a DSN string which can be passed to +// the driver. +func (cfg *Config) FormatDSN() string { + var buf bytes.Buffer + + // [username[:password]@] + if len(cfg.User) > 0 { + buf.WriteString(cfg.User) + if len(cfg.Passwd) > 0 { + buf.WriteByte(':') + buf.WriteString(cfg.Passwd) + } + buf.WriteByte('@') + } + + // [protocol[(address)]] + if len(cfg.Net) > 0 { + buf.WriteString(cfg.Net) + if len(cfg.Addr) > 0 { + buf.WriteByte('(') + buf.WriteString(cfg.Addr) + buf.WriteByte(')') + } + } + + // /dbname + buf.WriteByte('/') + buf.WriteString(cfg.DBName) + + // [?param1=value1&...¶mN=valueN] + hasParam := false + + if cfg.AllowAllFiles { + hasParam = true + buf.WriteString("?allowAllFiles=true") + } + + if cfg.AllowCleartextPasswords { + if hasParam { + buf.WriteString("&allowCleartextPasswords=true") + } else { + hasParam = true + buf.WriteString("?allowCleartextPasswords=true") + } + } + + if cfg.AllowOldPasswords { + if hasParam { + buf.WriteString("&allowOldPasswords=true") + } else { + hasParam = true + buf.WriteString("?allowOldPasswords=true") + } + } + + if cfg.ClientFoundRows { + if hasParam { + buf.WriteString("&clientFoundRows=true") + } else { + hasParam = true + buf.WriteString("?clientFoundRows=true") + } + } + + if col := cfg.Collation; col != defaultCollation && len(col) > 0 { + if hasParam { + buf.WriteString("&collation=") + } else { + hasParam = true + buf.WriteString("?collation=") + } + buf.WriteString(col) + } + + if cfg.ColumnsWithAlias { + if hasParam { + buf.WriteString("&columnsWithAlias=true") + } else { + hasParam = true + buf.WriteString("?columnsWithAlias=true") + } + } + + if cfg.InterpolateParams { + if hasParam { + buf.WriteString("&interpolateParams=true") + } else { + hasParam = true + buf.WriteString("?interpolateParams=true") + } + } + + if cfg.Loc != time.UTC && cfg.Loc != nil { + if hasParam { + buf.WriteString("&loc=") + } else { + hasParam = true + buf.WriteString("?loc=") + } + buf.WriteString(url.QueryEscape(cfg.Loc.String())) + } + + if cfg.MultiStatements { + if hasParam { + buf.WriteString("&multiStatements=true") + } else { + hasParam = true + buf.WriteString("?multiStatements=true") + } + } + + if cfg.ParseTime { + if hasParam { + buf.WriteString("&parseTime=true") + } else { + hasParam = true + buf.WriteString("?parseTime=true") + } + } + + if cfg.ReadTimeout > 0 { + if hasParam { + buf.WriteString("&readTimeout=") + } else { + hasParam = true + buf.WriteString("?readTimeout=") + } + buf.WriteString(cfg.ReadTimeout.String()) + } + + if cfg.Strict { + if hasParam { + buf.WriteString("&strict=true") + } else { + hasParam = true + buf.WriteString("?strict=true") + } + } + + if cfg.Timeout > 0 { + if hasParam { + buf.WriteString("&timeout=") + } else { + hasParam = true + buf.WriteString("?timeout=") + } + buf.WriteString(cfg.Timeout.String()) + } + + if len(cfg.TLSConfig) > 0 { + if hasParam { + buf.WriteString("&tls=") + } else { + hasParam = true + buf.WriteString("?tls=") + } + buf.WriteString(url.QueryEscape(cfg.TLSConfig)) + } + + if cfg.WriteTimeout > 0 { + if hasParam { + buf.WriteString("&writeTimeout=") + } else { + hasParam = true + buf.WriteString("?writeTimeout=") + } + buf.WriteString(cfg.WriteTimeout.String()) + } + + // other params + if cfg.Params != nil { + for param, value := range cfg.Params { + if hasParam { + buf.WriteByte('&') + } else { + hasParam = true + buf.WriteByte('?') + } + + buf.WriteString(param) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(value)) + } + } + + return buf.String() +} + +// ParseDSN parses the DSN string to a Config +func ParseDSN(dsn string) (cfg *Config, err error) { + // New config with some default values + cfg = &Config{ + Loc: time.UTC, + Collation: defaultCollation, + } + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.Passwd = dsn[k+1 : j] + break + } + } + cfg.User = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.Addr = dsn[k+1 : i-1] + break + } + } + cfg.Net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.DBName = dsn[i+1 : j] + + break + } + } + + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return nil, errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") + } + + } + + return +} + +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *Config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + + // Disable INFILE whitelist / enable all files + case "allowAllFiles": + var isBool bool + cfg.AllowAllFiles, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use cleartext authentication mode (MySQL 5.5.10+) + case "allowCleartextPasswords": + var isBool bool + cfg.AllowCleartextPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.AllowOldPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Switch "rowsAffected" mode + case "clientFoundRows": + var isBool bool + cfg.ClientFoundRows, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Collation + case "collation": + cfg.Collation = value + break + + case "columnsWithAlias": + var isBool bool + cfg.ColumnsWithAlias, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Compression + case "compress": + return errors.New("compression not implemented yet") + + // Enable client side placeholder substitution + case "interpolateParams": + var isBool bool + cfg.InterpolateParams, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Time Location + case "loc": + if value, err = url.QueryUnescape(value); err != nil { + return + } + cfg.Loc, err = time.LoadLocation(value) + if err != nil { + return + } + + // multiple statements in one query + case "multiStatements": + var isBool bool + cfg.MultiStatements, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // time.Time parsing + case "parseTime": + var isBool bool + cfg.ParseTime, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // I/O read Timeout + case "readTimeout": + cfg.ReadTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // Strict mode + case "strict": + var isBool bool + cfg.Strict, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Dial Timeout + case "timeout": + cfg.Timeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // TLS-Encryption + case "tls": + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.TLSConfig = "true" + cfg.tls = &tls.Config{} + } else { + cfg.TLSConfig = "false" + } + } else if vl := strings.ToLower(value); vl == "skip-verify" { + cfg.TLSConfig = vl + cfg.tls = &tls.Config{InsecureSkipVerify: true} + } else { + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for TLS config name: %v", err) + } + + if tlsConfig, ok := tlsConfigRegister[name]; ok { + if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + tlsConfig.ServerName = host + } + } + + cfg.TLSConfig = name + cfg.tls = tlsConfig + } else { + return errors.New("invalid value / unknown config name: " + name) + } + } + + // I/O write Timeout + case "writeTimeout": + cfg.WriteTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + + default: + // lazy init + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + } + + return +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/errors.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/errors.go new file mode 100644 index 0000000..1543a80 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/errors.go @@ -0,0 +1,131 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + "log" + "os" +) + +// Various errors the driver might return. Can change between driver versions. +var ( + ErrInvalidConn = errors.New("invalid connection") + ErrMalformPkt = errors.New("malformed packet") + ErrNoTLS = errors.New("TLS requested but server does not support TLS") + ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") + ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") + ErrUnknownPlugin = errors.New("this authentication plugin is not supported") + ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") + ErrPktSync = errors.New("commands out of sync. You can't run this command now") + ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") + ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") + ErrBusyBuffer = errors.New("busy buffer") +) + +var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) + +// Logger is used to log critical error messages. +type Logger interface { + Print(v ...interface{}) +} + +// SetLogger is used to set the logger for critical errors. +// The initial logger is os.Stderr. +func SetLogger(logger Logger) error { + if logger == nil { + return errors.New("logger is nil") + } + errLog = logger + return nil +} + +// MySQLError is an error type which represents a single MySQL error +type MySQLError struct { + Number uint16 + Message string +} + +func (me *MySQLError) Error() string { + return fmt.Sprintf("Error %d: %s", me.Number, me.Message) +} + +// MySQLWarnings is an error type which represents a group of one or more MySQL +// warnings +type MySQLWarnings []MySQLWarning + +func (mws MySQLWarnings) Error() string { + var msg string + for i, warning := range mws { + if i > 0 { + msg += "\r\n" + } + msg += fmt.Sprintf( + "%s %s: %s", + warning.Level, + warning.Code, + warning.Message, + ) + } + return msg +} + +// MySQLWarning is an error type which represents a single MySQL warning. +// Warnings are returned in groups only. See MySQLWarnings +type MySQLWarning struct { + Level string + Code string + Message string +} + +func (mc *mysqlConn) getWarnings() (err error) { + rows, err := mc.Query("SHOW WARNINGS", nil) + if err != nil { + return + } + + var warnings = MySQLWarnings{} + var values = make([]driver.Value, 3) + + for { + err = rows.Next(values) + switch err { + case nil: + warning := MySQLWarning{} + + if raw, ok := values[0].([]byte); ok { + warning.Level = string(raw) + } else { + warning.Level = fmt.Sprintf("%s", values[0]) + } + if raw, ok := values[1].([]byte); ok { + warning.Code = string(raw) + } else { + warning.Code = fmt.Sprintf("%s", values[1]) + } + if raw, ok := values[2].([]byte); ok { + warning.Message = string(raw) + } else { + warning.Message = fmt.Sprintf("%s", values[0]) + } + + warnings = append(warnings, warning) + + case io.EOF: + return warnings + + default: + rows.Close() + return + } + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/infile.go new file mode 100644 index 0000000..0f975bb --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/infile.go @@ -0,0 +1,181 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "fmt" + "io" + "os" + "strings" + "sync" +) + +var ( + fileRegister map[string]bool + fileRegisterLock sync.RWMutex + readerRegister map[string]func() io.Reader + readerRegisterLock sync.RWMutex +) + +// RegisterLocalFile adds the given file to the file whitelist, +// so that it can be used by "LOAD DATA LOCAL INFILE ". +// Alternatively you can allow the use of all local files with +// the DSN parameter 'allowAllFiles=true' +// +// filePath := "/home/gopher/data.csv" +// mysql.RegisterLocalFile(filePath) +// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterLocalFile(filePath string) { + fileRegisterLock.Lock() + // lazy map init + if fileRegister == nil { + fileRegister = make(map[string]bool) + } + + fileRegister[strings.Trim(filePath, `"`)] = true + fileRegisterLock.Unlock() +} + +// DeregisterLocalFile removes the given filepath from the whitelist. +func DeregisterLocalFile(filePath string) { + fileRegisterLock.Lock() + delete(fileRegister, strings.Trim(filePath, `"`)) + fileRegisterLock.Unlock() +} + +// RegisterReaderHandler registers a handler function which is used +// to receive a io.Reader. +// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::". +// If the handler returns a io.ReadCloser Close() is called when the +// request is finished. +// +// mysql.RegisterReaderHandler("data", func() io.Reader { +// var csvReader io.Reader // Some Reader that returns CSV data +// ... // Open Reader here +// return csvReader +// }) +// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterReaderHandler(name string, handler func() io.Reader) { + readerRegisterLock.Lock() + // lazy map init + if readerRegister == nil { + readerRegister = make(map[string]func() io.Reader) + } + + readerRegister[name] = handler + readerRegisterLock.Unlock() +} + +// DeregisterReaderHandler removes the ReaderHandler function with +// the given name from the registry. +func DeregisterReaderHandler(name string) { + readerRegisterLock.Lock() + delete(readerRegister, name) + readerRegisterLock.Unlock() +} + +func deferredClose(err *error, closer io.Closer) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr + } +} + +func (mc *mysqlConn) handleInFileRequest(name string) (err error) { + var rdr io.Reader + var data []byte + packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + if mc.maxWriteSize < packetSize { + packetSize = mc.maxWriteSize + } + + if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader + // The server might return an an absolute path. See issue #355. + name = name[idx+8:] + + readerRegisterLock.RLock() + handler, inMap := readerRegister[name] + readerRegisterLock.RUnlock() + + if inMap { + rdr = handler() + if rdr != nil { + if cl, ok := rdr.(io.Closer); ok { + defer deferredClose(&err, cl) + } + } else { + err = fmt.Errorf("Reader '%s' is ", name) + } + } else { + err = fmt.Errorf("Reader '%s' is not registered", name) + } + } else { // File + name = strings.Trim(name, `"`) + fileRegisterLock.RLock() + fr := fileRegister[name] + fileRegisterLock.RUnlock() + if mc.cfg.AllowAllFiles || fr { + var file *os.File + var fi os.FileInfo + + if file, err = os.Open(name); err == nil { + defer deferredClose(&err, file) + + // get file size + if fi, err = file.Stat(); err == nil { + rdr = file + if fileSize := int(fi.Size()); fileSize < packetSize { + packetSize = fileSize + } + } + } + } else { + err = fmt.Errorf("local file '%s' is not registered", name) + } + } + + // send content packets + if err == nil { + data := make([]byte, 4+packetSize) + var n int + for err == nil { + n, err = rdr.Read(data[4:]) + if n > 0 { + if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + return ioErr + } + } + } + if err == io.EOF { + err = nil + } + } + + // send empty packet (termination) + if data == nil { + data = make([]byte, 4) + } + if ioErr := mc.writePacket(data[:4]); ioErr != nil { + return ioErr + } + + // read OK packet + if err == nil { + return mc.readResultOK() + } + + mc.readPacket() + return err +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/packets.go new file mode 100644 index 0000000..8d91665 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/packets.go @@ -0,0 +1,1243 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/tls" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "time" +) + +// Packets documentation: +// http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +// Read packet to buffer 'data' +func (mc *mysqlConn) readPacket() ([]byte, error) { + var payload []byte + for { + // Read packet header + data, err := mc.buf.readNext(4) + if err != nil { + errLog.Print(err) + mc.Close() + return nil, driver.ErrBadConn + } + + // Packet Length [24 bit] + pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + + if pktLen < 1 { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, driver.ErrBadConn + } + + // Check Packet Sync [8 bit] + if data[3] != mc.sequence { + if data[3] > mc.sequence { + return nil, ErrPktSyncMul + } + return nil, ErrPktSync + } + mc.sequence++ + + // Read packet body [pktLen bytes] + data, err = mc.buf.readNext(pktLen) + if err != nil { + errLog.Print(err) + mc.Close() + return nil, driver.ErrBadConn + } + + isLastPacket := (pktLen < maxPacketSize) + + // Zero allocations for non-splitting packets + if isLastPacket && payload == nil { + return data, nil + } + + payload = append(payload, data...) + + if isLastPacket { + return payload, nil + } + } +} + +// Write packet buffer 'data' +func (mc *mysqlConn) writePacket(data []byte) error { + pktLen := len(data) - 4 + + if pktLen > mc.maxPacketAllowed { + return ErrPktTooLarge + } + + for { + var size int + if pktLen >= maxPacketSize { + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + size = maxPacketSize + } else { + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + size = pktLen + } + data[3] = mc.sequence + + // Write packet + if mc.writeTimeout > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + return err + } + } + + n, err := mc.netConn.Write(data[:4+size]) + if err == nil && n == 4+size { + mc.sequence++ + if size != maxPacketSize { + return nil + } + pktLen -= size + data = data[size:] + continue + } + + // Handle error + if err == nil { // n != len(data) + errLog.Print(ErrMalformPkt) + } else { + errLog.Print(err) + } + return driver.ErrBadConn + } +} + +/****************************************************************************** +* Initialisation Process * +******************************************************************************/ + +// Handshake Initialization Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +func (mc *mysqlConn) readInitPacket() ([]byte, error) { + data, err := mc.readPacket() + if err != nil { + return nil, err + } + + if data[0] == iERR { + return nil, mc.handleErrorPacket(data) + } + + // protocol version [1 byte] + if data[0] < minProtocolVersion { + return nil, fmt.Errorf( + "unsupported protocol version %d. Version %d or higher is required", + data[0], + minProtocolVersion, + ) + } + + // server version [null terminated string] + // connection id [4 bytes] + pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 + + // first part of the password cipher [8 bytes] + cipher := data[pos : pos+8] + + // (filler) always 0x00 [1 byte] + pos += 8 + 1 + + // capability flags (lower 2 bytes) [2 bytes] + mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if mc.flags&clientProtocol41 == 0 { + return nil, ErrOldProtocol + } + if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { + return nil, ErrNoTLS + } + pos += 2 + + if len(data) > pos { + // character set [1 byte] + // status flags [2 bytes] + // capability flags (upper 2 bytes) [2 bytes] + // length of auth-plugin-data [1 byte] + // reserved (all [00]) [10 bytes] + pos += 1 + 2 + 2 + 1 + 10 + + // second part of the password cipher [mininum 13 bytes], + // where len=MAX(13, length of auth-plugin-data - 8) + // + // The web documentation is ambiguous about the length. However, + // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, + // the 13th byte is "\0 byte, terminating the second part of + // a scramble". So the second part of the password cipher is + // a NULL terminated string that's at least 13 bytes with the + // last byte being NULL. + // + // The official Python library uses the fixed length 12 + // which seems to work but technically could have a hidden bug. + cipher = append(cipher, data[pos:pos+12]...) + + // TODO: Verify string termination + // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) + // \NUL otherwise + // + //if data[len(data)-1] == 0 { + // return + //} + //return ErrMalformPkt + + // make a memory safe copy of the cipher slice + var b [20]byte + copy(b[:], cipher) + return b[:], nil + } + + // make a memory safe copy of the cipher slice + var b [8]byte + copy(b[:], cipher) + return b[:], nil +} + +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { + // Adjust client flags based on server support + clientFlags := clientProtocol41 | + clientSecureConn | + clientLongPassword | + clientTransactions | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + mc.flags&clientLongFlag + + if mc.cfg.ClientFoundRows { + clientFlags |= clientFoundRows + } + + // To enable TLS / SSL + if mc.cfg.tls != nil { + clientFlags |= clientSSL + } + + if mc.cfg.MultiStatements { + clientFlags |= clientMultiStatements + } + + // User Password + scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 + + // To specify a db name + if n := len(mc.cfg.DBName); n > 0 { + clientFlags |= clientConnectWithDB + pktLen += n + 1 + } + + // Calculate packet length and get buffer with that size + data := mc.buf.takeSmallBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // ClientFlags [32 bit] + data[4] = byte(clientFlags) + data[5] = byte(clientFlags >> 8) + data[6] = byte(clientFlags >> 16) + data[7] = byte(clientFlags >> 24) + + // MaxPacketSize [32 bit] (none) + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 + + // Charset [1 byte] + var found bool + data[12], found = collations[mc.cfg.Collation] + if !found { + // Note possibility for false negatives: + // could be triggered although the collation is valid if the + // collations map does not contain entries the server supports. + return errors.New("unknown collation") + } + + // SSL Connection Request Packet + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + if mc.cfg.tls != nil { + // Send TLS / SSL request packet + if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + return err + } + + // Switch to TLS + tlsConn := tls.Client(mc.netConn, mc.cfg.tls) + if err := tlsConn.Handshake(); err != nil { + return err + } + mc.netConn = tlsConn + mc.buf.nc = tlsConn + } + + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + + // User [null terminated string] + if len(mc.cfg.User) > 0 { + pos += copy(data[pos:], mc.cfg.User) + } + data[pos] = 0x00 + pos++ + + // ScrambleBuffer [length encoded integer] + data[pos] = byte(len(scrambleBuff)) + pos += 1 + copy(data[pos+1:], scrambleBuff) + + // Databasename [null terminated string] + if len(mc.cfg.DBName) > 0 { + pos += copy(data[pos:], mc.cfg.DBName) + data[pos] = 0x00 + pos++ + } + + // Assume native client during response + pos += copy(data[pos:], "mysql_native_password") + data[pos] = 0x00 + + // Send Auth packet + return mc.writePacket(data) +} + +// Client old authentication packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { + // User password + scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) + + // Calculate the packet length and add a tailing 0 + pktLen := len(scrambleBuff) + 1 + data := mc.buf.takeSmallBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add the scrambled password [null terminated string] + copy(data[4:], scrambleBuff) + data[4+pktLen-1] = 0x00 + + return mc.writePacket(data) +} + +// Client clear text authentication packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeClearAuthPacket() error { + // Calculate the packet length and add a tailing 0 + pktLen := len(mc.cfg.Passwd) + 1 + data := mc.buf.takeSmallBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add the clear password [null terminated string] + copy(data[4:], mc.cfg.Passwd) + data[4+pktLen-1] = 0x00 + + return mc.writePacket(data) +} + +/****************************************************************************** +* Command Packets * +******************************************************************************/ + +func (mc *mysqlConn) writeCommandPacket(command byte) error { + // Reset Packet Sequence + mc.sequence = 0 + + data := mc.buf.takeSmallBuffer(4 + 1) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add command byte + data[4] = command + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { + // Reset Packet Sequence + mc.sequence = 0 + + pktLen := 1 + len(arg) + data := mc.buf.takeBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add command byte + data[4] = command + + // Add arg + copy(data[5:], arg) + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { + // Reset Packet Sequence + mc.sequence = 0 + + data := mc.buf.takeSmallBuffer(4 + 1 + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add command byte + data[4] = command + + // Add arg [32 bit] + data[5] = byte(arg) + data[6] = byte(arg >> 8) + data[7] = byte(arg >> 16) + data[8] = byte(arg >> 24) + + // Send CMD packet + return mc.writePacket(data) +} + +/****************************************************************************** +* Result Packets * +******************************************************************************/ + +// Returns error if Packet is not an 'Result OK'-Packet +func (mc *mysqlConn) readResultOK() error { + data, err := mc.readPacket() + if err == nil { + // packet indicator + switch data[0] { + + case iOK: + return mc.handleOkPacket(data) + + case iEOF: + if len(data) > 1 { + plugin := string(data[1:bytes.IndexByte(data, 0x00)]) + if plugin == "mysql_old_password" { + // using old_passwords + return ErrOldPassword + } else if plugin == "mysql_clear_password" { + // using clear text password + return ErrCleartextPassword + } else { + return ErrUnknownPlugin + } + } else { + return ErrOldPassword + } + + default: // Error otherwise + return mc.handleErrorPacket(data) + } + } + return err +} + +// Result Set Header Packet +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset +func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { + data, err := mc.readPacket() + if err == nil { + switch data[0] { + + case iOK: + return 0, mc.handleOkPacket(data) + + case iERR: + return 0, mc.handleErrorPacket(data) + + case iLocalInFile: + return 0, mc.handleInFileRequest(string(data[1:])) + } + + // column count + num, _, n := readLengthEncodedInteger(data) + if n-len(data) == 0 { + return int(num), nil + } + + return 0, ErrMalformPkt + } + return 0, err +} + +// Error Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet +func (mc *mysqlConn) handleErrorPacket(data []byte) error { + if data[0] != iERR { + return ErrMalformPkt + } + + // 0xff [1 byte] + + // Error Number [16 bit uint] + errno := binary.LittleEndian.Uint16(data[1:3]) + + pos := 3 + + // SQL State [optional: # + 5bytes string] + if data[3] == 0x23 { + //sqlstate := string(data[4 : 4+5]) + pos = 9 + } + + // Error Message [string] + return &MySQLError{ + Number: errno, + Message: string(data[pos:]), + } +} + +func readStatus(b []byte) statusFlag { + return statusFlag(b[0]) | statusFlag(b[1])<<8 +} + +// Ok Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +func (mc *mysqlConn) handleOkPacket(data []byte) error { + var n, m int + + // 0x00 [1 byte] + + // Affected rows [Length Coded Binary] + mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + + // Insert id [Length Coded Binary] + mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + + // server_status [2 bytes] + mc.status = readStatus(data[1+n+m : 1+n+m+2]) + if err := mc.discardResults(); err != nil { + return err + } + + // warning count [2 bytes] + if !mc.strict { + return nil + } + + pos := 1 + n + m + 2 + if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { + return mc.getWarnings() + } + return nil +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 +func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { + columns := make([]mysqlField, count) + + for i := 0; ; i++ { + data, err := mc.readPacket() + if err != nil { + return nil, err + } + + // EOF Packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if i == count { + return columns, nil + } + return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) + } + + // Catalog + pos, err := skipLengthEncodedString(data) + if err != nil { + return nil, err + } + + // Database [len coded string] + n, err := skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Table [len coded string] + if mc.cfg.ColumnsWithAlias { + tableName, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + columns[i].tableName = string(tableName) + } else { + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + } + + // Original table [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Name [len coded string] + name, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + columns[i].name = string(name) + pos += n + + // Original name [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + + // Filler [uint8] + // Charset [charset, collation uint8] + // Length [uint32] + pos += n + 1 + 2 + 4 + + // Field type [uint8] + columns[i].fieldType = data[pos] + pos++ + + // Flags [uint16] + columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + pos += 2 + + // Decimals [uint8] + columns[i].decimals = data[pos] + //pos++ + + // Default value [len coded binary] + //if pos < len(data) { + // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) + //} + } +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow +func (rows *textRows) readRow(dest []driver.Value) error { + mc := rows.mc + + data, err := mc.readPacket() + if err != nil { + return err + } + + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + if err := rows.mc.discardResults(); err != nil { + return err + } + rows.mc = nil + return io.EOF + } + if data[0] == iERR { + rows.mc = nil + return mc.handleErrorPacket(data) + } + + // RowSet Packet + var n int + var isNull bool + pos := 0 + + for i := range dest { + // Read bytes and convert to string + dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + if err == nil { + if !isNull { + if !mc.parseTime { + continue + } else { + switch rows.columns[i].fieldType { + case fieldTypeTimestamp, fieldTypeDateTime, + fieldTypeDate, fieldTypeNewDate: + dest[i], err = parseDateTime( + string(dest[i].([]byte)), + mc.cfg.Loc, + ) + if err == nil { + continue + } + default: + continue + } + } + + } else { + dest[i] = nil + continue + } + } + return err // err != nil + } + + return nil +} + +// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read +func (mc *mysqlConn) readUntilEOF() error { + for { + data, err := mc.readPacket() + + // No Err and no EOF Packet + if err == nil && data[0] != iEOF { + continue + } + if err == nil && data[0] == iEOF && len(data) == 5 { + mc.status = readStatus(data[3:]) + } + + return err // Err or EOF + } +} + +/****************************************************************************** +* Prepared Statements * +******************************************************************************/ + +// Prepare Result Packets +// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html +func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { + data, err := stmt.mc.readPacket() + if err == nil { + // packet indicator [1 byte] + if data[0] != iOK { + return 0, stmt.mc.handleErrorPacket(data) + } + + // statement id [4 bytes] + stmt.id = binary.LittleEndian.Uint32(data[1:5]) + + // Column count [16 bit uint] + columnCount := binary.LittleEndian.Uint16(data[5:7]) + + // Param count [16 bit uint] + stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) + + // Reserved [8 bit] + + // Warning count [16 bit uint] + if !stmt.mc.strict { + return columnCount, nil + } + + // Check for warnings count > 0, only available in MySQL > 4.1 + if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { + return columnCount, stmt.mc.getWarnings() + } + return columnCount, nil + } + return 0, err +} + +// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html +func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { + maxLen := stmt.mc.maxPacketAllowed - 1 + pktLen := maxLen + + // After the header (bytes 0-3) follows before the data: + // 1 byte command + // 4 bytes stmtID + // 2 bytes paramID + const dataOffset = 1 + 4 + 2 + + // Can not use the write buffer since + // a) the buffer is too small + // b) it is in use + data := make([]byte, 4+1+4+2+len(arg)) + + copy(data[4+dataOffset:], arg) + + for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { + if dataOffset+argLen < maxLen { + pktLen = dataOffset + argLen + } + + stmt.mc.sequence = 0 + // Add command byte [1 byte] + data[4] = comStmtSendLongData + + // Add stmtID [32 bit] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // Add paramID [16 bit] + data[9] = byte(paramID) + data[10] = byte(paramID >> 8) + + // Send CMD packet + err := stmt.mc.writePacket(data[:4+pktLen]) + if err == nil { + data = data[pktLen-dataOffset:] + continue + } + return err + + } + + // Reset Packet Sequence + stmt.mc.sequence = 0 + return nil +} + +// Execute Prepared Statement +// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html +func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { + if len(args) != stmt.paramCount { + return fmt.Errorf( + "argument count mismatch (got: %d; has: %d)", + len(args), + stmt.paramCount, + ) + } + + const minPktLen = 4 + 1 + 4 + 1 + 4 + mc := stmt.mc + + // Reset packet-sequence + mc.sequence = 0 + + var data []byte + + if len(args) == 0 { + data = mc.buf.takeBuffer(minPktLen) + } else { + data = mc.buf.takeCompleteBuffer() + } + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // command [1 byte] + data[4] = comStmtExecute + + // statement_id [4 bytes] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] + data[9] = 0x00 + + // iteration_count (uint32(1)) [4 bytes] + data[10] = 0x01 + data[11] = 0x00 + data[12] = 0x00 + data[13] = 0x00 + + if len(args) > 0 { + pos := minPktLen + + var nullMask []byte + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + // buffer has to be extended but we don't know by how much so + // we depend on append after all data with known sizes fit. + // We stop at that because we deal with a lot of columns here + // which makes the required allocation size hard to guess. + tmp := make([]byte, pos+maskLen+typesLen) + copy(tmp[:pos], data[:pos]) + data = tmp + nullMask = data[pos : pos+maskLen] + pos += maskLen + } else { + nullMask = data[pos : pos+maskLen] + for i := 0; i < maskLen; i++ { + nullMask[i] = 0 + } + pos += maskLen + } + + // newParameterBoundFlag 1 [1 byte] + data[pos] = 0x01 + pos++ + + // type of each parameter [len(args)*2 bytes] + paramTypes := data[pos:] + pos += len(args) * 2 + + // value of each parameter [n bytes] + paramValues := data[pos:pos] + valuesCap := cap(paramValues) + + for i, arg := range args { + // build NULL-bitmap + if arg == nil { + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i+1] = 0x00 + continue + } + + // cache types and values + switch v := arg.(type) { + case int64: + paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + + case float64: + paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + math.Float64bits(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(math.Float64bits(v))..., + ) + } + + case bool: + paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i+1] = 0x00 + + if v { + paramValues = append(paramValues, 0x01) + } else { + paramValues = append(paramValues, 0x00) + } + + case []byte: + // Common case (non-nil value) first + if v != nil { + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, v); err != nil { + return err + } + } + continue + } + + // Handle []byte(nil) as a NULL value + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i+1] = 0x00 + + case string: + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err + } + } + + case time.Time: + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + var val []byte + if v.IsZero() { + val = []byte("0000-00-00") + } else { + val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) + } + + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(val)), + ) + paramValues = append(paramValues, val...) + + default: + return fmt.Errorf("can not convert type: %T", arg) + } + } + + // Check if param values exceeded the available buffer + // In that case we must build the data packet with the new values buffer + if valuesCap != cap(paramValues) { + data = append(data[:pos], paramValues...) + mc.buf.buf = data + } + + pos += len(paramValues) + data = data[:pos] + } + + return mc.writePacket(data) +} + +func (mc *mysqlConn) discardResults() error { + for mc.status&statusMoreResultsExists != 0 { + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } else { + mc.status &^= statusMoreResultsExists + } + } + return nil +} + +// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html +func (rows *binaryRows) readRow(dest []driver.Value) error { + data, err := rows.mc.readPacket() + if err != nil { + return err + } + + // packet indicator [1 byte] + if data[0] != iOK { + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + rows.mc.status = readStatus(data[3:]) + if err := rows.mc.discardResults(); err != nil { + return err + } + rows.mc = nil + return io.EOF + } + rows.mc = nil + + // Error otherwise + return rows.mc.handleErrorPacket(data) + } + + // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] + pos := 1 + (len(dest)+7+2)>>3 + nullMask := data[1:pos] + + for i := range dest { + // Field is NULL + // (byte >> bit-pos) % 2 == 1 + if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { + dest[i] = nil + continue + } + + // Convert to byte-coded string + switch rows.columns[i].fieldType { + case fieldTypeNULL: + dest[i] = nil + continue + + // Numeric Types + case fieldTypeTiny: + if rows.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(data[pos]) + } else { + dest[i] = int64(int8(data[pos])) + } + pos++ + continue + + case fieldTypeShort, fieldTypeYear: + if rows.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) + } else { + dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) + } + pos += 2 + continue + + case fieldTypeInt24, fieldTypeLong: + if rows.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) + } else { + dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) + } + pos += 4 + continue + + case fieldTypeLongLong: + if rows.columns[i].flags&flagUnsigned != 0 { + val := binary.LittleEndian.Uint64(data[pos : pos+8]) + if val > math.MaxInt64 { + dest[i] = uint64ToString(val) + } else { + dest[i] = int64(val) + } + } else { + dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) + } + pos += 8 + continue + + case fieldTypeFloat: + dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + pos += 4 + continue + + case fieldTypeDouble: + dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) + pos += 8 + continue + + // Length coded Binary Strings + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: + var isNull bool + var n int + dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + if err == nil { + if !isNull { + continue + } else { + dest[i] = nil + continue + } + } + return err + + case + fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD + fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] + fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] + + num, isNull, n := readLengthEncodedInteger(data[pos:]) + pos += n + + switch { + case isNull: + dest[i] = nil + continue + case rows.columns[i].fieldType == fieldTypeTime: + // database/sql does not support an equivalent to TIME, return a string + var dstlen uint8 + switch decimals := rows.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 8 + case 1, 2, 3, 4, 5, 6: + dstlen = 8 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.columns[i].decimals, + ) + } + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + case rows.mc.parseTime: + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) + default: + var dstlen uint8 + if rows.columns[i].fieldType == fieldTypeDate { + dstlen = 10 + } else { + switch decimals := rows.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 19 + case 1, 2, 3, 4, 5, 6: + dstlen = 19 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.columns[i].decimals, + ) + } + } + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + } + + if err == nil { + pos += int(num) + continue + } else { + return err + } + + // Please report if this happens! + default: + return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + } + } + + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/result.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/result.go new file mode 100644 index 0000000..c6438d0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/result.go @@ -0,0 +1,22 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +type mysqlResult struct { + affectedRows int64 + insertId int64 +} + +func (res *mysqlResult) LastInsertId() (int64, error) { + return res.insertId, nil +} + +func (res *mysqlResult) RowsAffected() (int64, error) { + return res.affectedRows, nil +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/rows.go new file mode 100644 index 0000000..c08255e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/rows.go @@ -0,0 +1,112 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "io" +) + +type mysqlField struct { + tableName string + name string + flags fieldFlag + fieldType byte + decimals byte +} + +type mysqlRows struct { + mc *mysqlConn + columns []mysqlField +} + +type binaryRows struct { + mysqlRows +} + +type textRows struct { + mysqlRows +} + +type emptyRows struct{} + +func (rows *mysqlRows) Columns() []string { + columns := make([]string, len(rows.columns)) + if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { + for i := range columns { + if tableName := rows.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.columns[i].name + } else { + columns[i] = rows.columns[i].name + } + } + } else { + for i := range columns { + columns[i] = rows.columns[i].name + } + } + return columns +} + +func (rows *mysqlRows) Close() error { + mc := rows.mc + if mc == nil { + return nil + } + if mc.netConn == nil { + return ErrInvalidConn + } + + // Remove unread packets from stream + err := mc.readUntilEOF() + if err == nil { + if err = mc.discardResults(); err != nil { + return err + } + } + + rows.mc = nil + return err +} + +func (rows *binaryRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if mc.netConn == nil { + return ErrInvalidConn + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} + +func (rows *textRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if mc.netConn == nil { + return ErrInvalidConn + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} + +func (rows emptyRows) Columns() []string { + return nil +} + +func (rows emptyRows) Close() error { + return nil +} + +func (rows emptyRows) Next(dest []driver.Value) error { + return io.EOF +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/statement.go new file mode 100644 index 0000000..ead9a6b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/statement.go @@ -0,0 +1,150 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "fmt" + "reflect" + "strconv" +) + +type mysqlStmt struct { + mc *mysqlConn + id uint32 + paramCount int + columns []mysqlField // cached from the first query +} + +func (stmt *mysqlStmt) Close() error { + if stmt.mc == nil || stmt.mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + stmt.mc = nil + return err +} + +func (stmt *mysqlStmt) NumInput() int { + return stmt.paramCount +} + +func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { + return converter{} +} + +func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, err + } + + mc := stmt.mc + + mc.affectedRows = 0 + mc.insertId = 0 + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil { + if resLen > 0 { + // Columns + err = mc.readUntilEOF() + if err != nil { + return nil, err + } + + // Rows + err = mc.readUntilEOF() + } + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil + } + } + + return nil, err +} + +func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, err + } + + mc := stmt.mc + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + rows := new(binaryRows) + + if resLen > 0 { + rows.mc = mc + // Columns + // If not cached, read them and cache them + if stmt.columns == nil { + rows.columns, err = mc.readColumns(resLen) + stmt.columns = rows.columns + } else { + rows.columns = stmt.columns + err = mc.readUntilEOF() + } + } + + return rows, err +} + +type converter struct{} + +func (c converter) ConvertValue(v interface{}) (driver.Value, error) { + if driver.IsValue(v) { + return v, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Ptr: + // indirect pointers + if rv.IsNil() { + return nil, nil + } + return c.ConvertValue(rv.Elem().Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return strconv.FormatUint(u64, 10), nil + } + return int64(u64), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/transaction.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/transaction.go new file mode 100644 index 0000000..33c749b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/transaction.go @@ -0,0 +1,31 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +type mysqlTx struct { + mc *mysqlConn +} + +func (tx *mysqlTx) Commit() (err error) { + if tx.mc == nil || tx.mc.netConn == nil { + return ErrInvalidConn + } + err = tx.mc.exec("COMMIT") + tx.mc = nil + return +} + +func (tx *mysqlTx) Rollback() (err error) { + if tx.mc == nil || tx.mc.netConn == nil { + return ErrInvalidConn + } + err = tx.mc.exec("ROLLBACK") + tx.mc = nil + return +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/utils.go new file mode 100644 index 0000000..d523b7f --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/go-sql-driver/mysql/utils.go @@ -0,0 +1,740 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/sha1" + "crypto/tls" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "strings" + "time" +) + +var ( + tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs +) + +// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. +// Use the key as a value in the DSN where tls=value. +// +// rootCertPool := x509.NewCertPool() +// pem, err := ioutil.ReadFile("/path/ca-cert.pem") +// if err != nil { +// log.Fatal(err) +// } +// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { +// log.Fatal("Failed to append PEM.") +// } +// clientCert := make([]tls.Certificate, 0, 1) +// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") +// if err != nil { +// log.Fatal(err) +// } +// clientCert = append(clientCert, certs) +// mysql.RegisterTLSConfig("custom", &tls.Config{ +// RootCAs: rootCertPool, +// Certificates: clientCert, +// }) +// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") +// +func RegisterTLSConfig(key string, config *tls.Config) error { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { + return fmt.Errorf("key '%s' is reserved", key) + } + + if tlsConfigRegister == nil { + tlsConfigRegister = make(map[string]*tls.Config) + } + + tlsConfigRegister[key] = config + return nil +} + +// DeregisterTLSConfig removes the tls.Config associated with key. +func DeregisterTLSConfig(key string) { + if tlsConfigRegister != nil { + delete(tlsConfigRegister, key) + } +} + +// Returns the bool value of the input. +// The 2nd return value indicates if the input was a valid bool value +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true + } + + // Not a valid bool value + return +} + +/****************************************************************************** +* Authentication * +******************************************************************************/ + +// Encrypt password using 4.1+ method +func scramblePassword(scramble, password []byte) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write(password) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +// Encrypt password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +// Pseudo random number generator +func newMyRnd(seed1, seed2 uint32) *myRnd { + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } +} + +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) +} + +// Generate binary hash from byte string using insecure pre 4.1 method +func pwHash(password []byte) (result [2]uint32) { + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 + + for _, c := range password { + // skip spaces and tabs in password + if c == ' ' || c == '\t' { + continue + } + + tmp = uint32(c) + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] + add += tmp + } + + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + + return +} + +// Encrypt password using insecure pre 4.1 method +func scrambleOldPassword(scramble, password []byte) []byte { + if len(password) == 0 { + return nil + } + + scramble = scramble[:8] + + hashPw := pwHash(password) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i := range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} + +/****************************************************************************** +* Time related utils * +******************************************************************************/ + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = parseDateTime(string(v), time.UTC) + nt.Valid = (err == nil) + return + case string: + nt.Time, err = parseDateTime(v, time.UTC) + nt.Valid = (err == nil) + return + } + + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + +func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { + base := "0000-00-00 00:00:00.0000000" + switch len(str) { + case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" + if str == base[:len(str)] { + return + } + t, err = time.Parse(timeFormat[:len(str)], str) + default: + err = fmt.Errorf("invalid time string: %s", str) + return + } + + // Adjust location + if err == nil && loc != time.UTC { + y, mo, d := t.Date() + h, mi, s := t.Clock() + t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil + } + + return +} + +func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { + switch num { + case 0: + return time.Time{}, nil + case 4: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + 0, 0, 0, 0, + loc, + ), nil + case 7: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + int(data[4]), // hour + int(data[5]), // minutes + int(data[6]), // seconds + 0, + loc, + ), nil + case 11: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + int(data[4]), // hour + int(data[5]), // minutes + int(data[6]), // seconds + int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds + loc, + ), nil + } + return nil, fmt.Errorf("invalid DATETIME packet length %d", num) +} + +// zeroDateTime is used in formatBinaryDateTime to avoid an allocation +// if the DATE or DATETIME has the zero value. +// It must never be changed. +// The current behavior depends on database/sql copying the result. +var zeroDateTime = []byte("0000-00-00 00:00:00.000000") + +const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" + +func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + if justTime { + return zeroDateTime[11 : 11+length], nil + } + return zeroDateTime[:length], nil + } + var dst []byte // return value + var pt, p1, p2, p3 byte // current digit pair + var zOffs byte // offset of value in zeroDateTime + if justTime { + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds + default: + return nil, fmt.Errorf("illegal TIME length %d", length) + } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) + } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + if src[1] != 0 { + hour := uint16(src[1])*24 + uint16(src[5]) + pt = byte(hour / 100) + p1 = byte(hour - 100*uint16(pt)) + dst = append(dst, digits01[pt]) + } else { + p1 = src[5] + } + zOffs = 11 + src = src[6:] + } else { + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s length %d", t, length) + } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt = byte(year / 100) + p1 = byte(year - 100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + } + // p1 is 2-digit hour, src is after hour + p2, p3 = src[0], src[1] + dst = append(dst, + digits10[p1], digits01[p1], ':', + digits10[p2], digits01[p2], ':', + digits10[p3], digits01[p3], + ) + if length <= byte(len(dst)) { + return dst, nil + } + src = src[2:] + if len(src) == 0 { + return append(dst, zeroDateTime[19:zOffs+length]...), nil + } + microsecs := binary.LittleEndian.Uint32(src[:4]) + p1 = byte(microsecs / 10000) + microsecs -= 10000 * uint32(p1) + p2 = byte(microsecs / 100) + microsecs -= 100 * uint32(p2) + p3 = byte(microsecs) + switch decimals := zOffs + length - 20; decimals { + default: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], digits01[p3], + ), nil + case 1: + return append(dst, '.', + digits10[p1], + ), nil + case 2: + return append(dst, '.', + digits10[p1], digits01[p1], + ), nil + case 3: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], + ), nil + case 4: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + ), nil + case 5: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], + ), nil + } +} + +/****************************************************************************** +* Convert from and to bytes * +******************************************************************************/ + +func uint64ToBytes(n uint64) []byte { + return []byte{ + byte(n), + byte(n >> 8), + byte(n >> 16), + byte(n >> 24), + byte(n >> 32), + byte(n >> 40), + byte(n >> 48), + byte(n >> 56), + } +} + +func uint64ToString(n uint64) []byte { + var a [20]byte + i := 20 + + // U+0030 = 0 + // ... + // U+0039 = 9 + + var q uint64 + for n >= 10 { + i-- + q = n / 10 + a[i] = uint8(n-q*10) + 0x30 + n = q + } + + i-- + a[i] = uint8(n) + 0x30 + + return a[i:] +} + +// treats string value as unsigned integer representation +func stringToInt(b []byte) int { + val := 0 + for i := range b { + val *= 10 + val += int(b[i] - 0x30) + } + return val +} + +// returns the string read as a bytes slice, wheter the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice +func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { + // Get length + num, isNull, n := readLengthEncodedInteger(b) + if num < 1 { + return b[n:n], isNull, n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return b[n-int(num) : n], false, n, nil + } + return nil, false, n, io.EOF +} + +// returns the number of bytes skipped and an error, in case the string is +// longer than the input slice +func skipLengthEncodedString(b []byte) (int, error) { + // Get length + num, _, n := readLengthEncodedInteger(b) + if num < 1 { + return n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return n, nil + } + return n, io.EOF +} + +// returns the number read, whether the value is NULL and the number of bytes read +func readLengthEncodedInteger(b []byte) (uint64, bool, int) { + // See issue #349 + if len(b) == 0 { + return 0, true, 1 + } + switch b[0] { + + // 251: NULL + case 0xfb: + return 0, true, 1 + + // 252: value of following 2 + case 0xfc: + return uint64(b[1]) | uint64(b[2])<<8, false, 3 + + // 253: value of following 3 + case 0xfd: + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 + + // 254: value of following 8 + case 0xfe: + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | + uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | + uint64(b[7])<<48 | uint64(b[8])<<56, + false, 9 + } + + // 0-250: value of first byte + return uint64(b[0]), false, 1 +} + +// encodes a uint64 value and appends it to the given bytes slice +func appendLengthEncodedInteger(b []byte, n uint64) []byte { + switch { + case n <= 250: + return append(b, byte(n)) + + case n <= 0xffff: + return append(b, 0xfc, byte(n), byte(n>>8)) + + case n <= 0xffffff: + return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) + } + return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), + byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) +} + +// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. +// If cap(buf) is not enough, reallocate new buffer. +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + // Grow buffer exponentially + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash escapes []byte with backslashes (\) +// This escapes the contents of a string (provided as []byte) by adding backslashes before special +// characters, and turning others into specific escape sequences, such as +// turning newlines into \n and null bytes into \0. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 +func escapeBytesBackslash(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash is similar to escapeBytesBackslash but for string. +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. +// This escapes the contents of a string by doubling up any apostrophes that +// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in +// effect on the server. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 +func escapeBytesQuotes(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringQuotes is similar to escapeBytesQuotes but for string. +func escapeStringQuotes(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/LICENSE new file mode 100644 index 0000000..0d31edf --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/LICENSE @@ -0,0 +1,23 @@ + Copyright (c) 2013, Jason Moiron + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without + restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following + conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/bind.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/bind.go new file mode 100644 index 0000000..564635c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/bind.go @@ -0,0 +1,186 @@ +package sqlx + +import ( + "bytes" + "errors" + "reflect" + "strconv" + "strings" + + "github.com/jmoiron/sqlx/reflectx" +) + +// Bindvar types supported by Rebind, BindMap and BindStruct. +const ( + UNKNOWN = iota + QUESTION + DOLLAR + NAMED +) + +// BindType returns the bindtype for a given database given a drivername. +func BindType(driverName string) int { + switch driverName { + case "postgres", "pgx": + return DOLLAR + case "mysql": + return QUESTION + case "sqlite3": + return QUESTION + case "oci8": + return NAMED + } + return UNKNOWN +} + +// FIXME: this should be able to be tolerant of escaped ?'s in queries without +// losing much speed, and should be to avoid confusion. + +// Rebind a query from the default bindtype (QUESTION) to the target bindtype. +func Rebind(bindType int, query string) string { + switch bindType { + case QUESTION, UNKNOWN: + return query + } + + qb := []byte(query) + // Add space enough for 10 params before we have to allocate + rqb := make([]byte, 0, len(qb)+10) + j := 1 + for _, b := range qb { + if b == '?' { + switch bindType { + case DOLLAR: + rqb = append(rqb, '$') + case NAMED: + rqb = append(rqb, ':', 'a', 'r', 'g') + } + for _, b := range strconv.Itoa(j) { + rqb = append(rqb, byte(b)) + } + j++ + } else { + rqb = append(rqb, b) + } + } + return string(rqb) +} + +// Experimental implementation of Rebind which uses a bytes.Buffer. The code is +// much simpler and should be more resistant to odd unicode, but it is twice as +// slow. Kept here for benchmarking purposes and to possibly replace Rebind if +// problems arise with its somewhat naive handling of unicode. +func rebindBuff(bindType int, query string) string { + if bindType != DOLLAR { + return query + } + + b := make([]byte, 0, len(query)) + rqb := bytes.NewBuffer(b) + j := 1 + for _, r := range query { + if r == '?' { + rqb.WriteRune('$') + rqb.WriteString(strconv.Itoa(j)) + j++ + } else { + rqb.WriteRune(r) + } + } + + return rqb.String() +} + +// In expands slice values in args, returning the modified query string +// and a new arg list that can be executed by a database. The `query` should +// use the `?` bindVar. The return value uses the `?` bindVar. +func In(query string, args ...interface{}) (string, []interface{}, error) { + // argMeta stores reflect.Value and length for slices and + // the value itself for non-slice arguments + type argMeta struct { + v reflect.Value + i interface{} + length int + } + + var flatArgsCount int + var anySlices bool + + meta := make([]argMeta, len(args)) + + for i, arg := range args { + v := reflect.ValueOf(arg) + t := reflectx.Deref(v.Type()) + + if t.Kind() == reflect.Slice { + meta[i].length = v.Len() + meta[i].v = v + + anySlices = true + flatArgsCount += meta[i].length + + if meta[i].length == 0 { + return "", nil, errors.New("empty slice passed to 'in' query") + } + } else { + meta[i].i = arg + flatArgsCount++ + } + } + + // don't do any parsing if there aren't any slices; note that this means + // some errors that we might have caught below will not be returned. + if !anySlices { + return query, args, nil + } + + newArgs := make([]interface{}, 0, flatArgsCount) + + var arg, offset int + var buf bytes.Buffer + + for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { + if arg >= len(meta) { + // if an argument wasn't passed, lets return an error; this is + // not actually how database/sql Exec/Query works, but since we are + // creating an argument list programmatically, we want to be able + // to catch these programmer errors earlier. + return "", nil, errors.New("number of bindVars exceeds arguments") + } + + argMeta := meta[arg] + arg++ + + // not a slice, continue. + // our questionmark will either be written before the next expansion + // of a slice or after the loop when writing the rest of the query + if argMeta.length == 0 { + offset = offset + i + 1 + newArgs = append(newArgs, argMeta.i) + continue + } + + // write everything up to and including our ? character + buf.WriteString(query[:offset+i+1]) + + newArgs = append(newArgs, argMeta.v.Index(0).Interface()) + + for si := 1; si < argMeta.length; si++ { + buf.WriteString(", ?") + newArgs = append(newArgs, argMeta.v.Index(si).Interface()) + } + + // slice the query and reset the offset. this avoids some bookkeeping for + // the write after the loop + query = query[offset+i+1:] + offset = 0 + } + + buf.WriteString(query) + + if arg < len(meta) { + return "", nil, errors.New("number of bindVars less than number arguments") + } + + return buf.String(), newArgs, nil +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/doc.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/doc.go new file mode 100644 index 0000000..e2b4e60 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/doc.go @@ -0,0 +1,12 @@ +// Package sqlx provides general purpose extensions to database/sql. +// +// It is intended to seamlessly wrap database/sql and provide convenience +// methods which are useful in the development of database driven applications. +// None of the underlying database/sql methods are changed. Instead all extended +// behavior is implemented through new methods defined on wrapper types. +// +// Additions include scanning into structs, named query support, rebinding +// queries for different drivers, convenient shorthands for common error handling +// and more. +// +package sqlx diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/named.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/named.go new file mode 100644 index 0000000..4df8095 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/named.go @@ -0,0 +1,336 @@ +package sqlx + +// Named Query Support +// +// * BindMap - bind query bindvars to map/struct args +// * NamedExec, NamedQuery - named query w/ struct or map +// * NamedStmt - a pre-compiled named query which is a prepared statement +// +// Internal Interfaces: +// +// * compileNamedQuery - rebind a named query, returning a query and list of names +// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist +// +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strconv" + "unicode" + + "github.com/jmoiron/sqlx/reflectx" +) + +// NamedStmt is a prepared statement that executes named queries. Prepare it +// how you would execute a NamedQuery, but pass in a struct or map when executing. +type NamedStmt struct { + Params []string + QueryString string + Stmt *Stmt +} + +// Close closes the named statement. +func (n *NamedStmt) Close() error { + return n.Stmt.Close() +} + +// Exec executes a named statement using the struct passed. +func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return *new(sql.Result), err + } + return n.Stmt.Exec(args...) +} + +// Query executes a named statement using the struct argument, returning rows. +func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return nil, err + } + return n.Stmt.Query(args...) +} + +// QueryRow executes a named statement against the database. Because sqlx cannot +// create a *sql.Row with an error condition pre-set for binding errors, sqlx +// returns a *sqlx.Row instead. +func (n *NamedStmt) QueryRow(arg interface{}) *Row { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return &Row{err: err} + } + return n.Stmt.QueryRowx(args...) +} + +// MustExec execs a NamedStmt, panicing on error +func (n *NamedStmt) MustExec(arg interface{}) sql.Result { + res, err := n.Exec(arg) + if err != nil { + panic(err) + } + return res +} + +// Queryx using this NamedStmt +func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { + r, err := n.Query(arg) + if err != nil { + return nil, err + } + return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err +} + +// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is +// an alias for QueryRow. +func (n *NamedStmt) QueryRowx(arg interface{}) *Row { + return n.QueryRow(arg) +} + +// Select using this NamedStmt +func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { + rows, err := n.Queryx(arg) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// Get using this NamedStmt +func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { + r := n.QueryRowx(arg) + return r.scanAny(dest, false) +} + +// Unsafe creates an unsafe version of the NamedStmt +func (n *NamedStmt) Unsafe() *NamedStmt { + r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString} + r.Stmt.unsafe = true + return r +} + +// A union interface of preparer and binder, required to be able to prepare +// named statements (as the bindtype must be determined). +type namedPreparer interface { + Preparer + binder +} + +func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { + bindType := BindType(p.DriverName()) + q, args, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return nil, err + } + stmt, err := Preparex(p, q) + if err != nil { + return nil, err + } + return &NamedStmt{ + QueryString: q, + Params: args, + Stmt: stmt, + }, nil +} + +func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { + if maparg, ok := arg.(map[string]interface{}); ok { + return bindMapArgs(names, maparg) + } + return bindArgs(names, arg, m) +} + +// private interface to generate a list of interfaces from a given struct +// type, given a list of names to pull out of the struct. Used by public +// BindStruct interface. +func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { + arglist := make([]interface{}, 0, len(names)) + + // grab the indirected value of arg + v := reflect.ValueOf(arg) + for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { + v = v.Elem() + } + + fields := m.TraversalsByName(v.Type(), names) + for i, t := range fields { + if len(t) == 0 { + return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg) + } + val := reflectx.FieldByIndexesReadOnly(v, t) + arglist = append(arglist, val.Interface()) + } + + return arglist, nil +} + +// like bindArgs, but for maps. +func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) { + arglist := make([]interface{}, 0, len(names)) + + for _, name := range names { + val, ok := arg[name] + if !ok { + return arglist, fmt.Errorf("could not find name %s in %#v", name, arg) + } + arglist = append(arglist, val) + } + return arglist, nil +} + +// bindStruct binds a named parameter query with fields from a struct argument. +// The rules for binding field names to parameter names follow the same +// conventions as for StructScan, including obeying the `db` struct tags. +func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { + bound, names, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return "", []interface{}{}, err + } + + arglist, err := bindArgs(names, arg, m) + if err != nil { + return "", []interface{}{}, err + } + + return bound, arglist, nil +} + +// bindMap binds a named parameter query with a map of arguments. +func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { + bound, names, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return "", []interface{}{}, err + } + + arglist, err := bindMapArgs(names, args) + return bound, arglist, err +} + +// -- Compilation of Named Queries + +// Allow digits and letters in bind params; additionally runes are +// checked against underscores, meaning that bind params can have be +// alphanumeric with underscores. Mind the difference between unicode +// digits and numbers, where '5' is a digit but '五' is not. +var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} + +// FIXME: this function isn't safe for unicode named params, as a failing test +// can testify. This is not a regression but a failure of the original code +// as well. It should be modified to range over runes in a string rather than +// bytes, even though this is less convenient and slower. Hopefully the +// addition of the prepared NamedStmt (which will only do this once) will make +// up for the slightly slower ad-hoc NamedExec/NamedQuery. + +// compile a NamedQuery into an unbound query (using the '?' bindvar) and +// a list of names. +func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { + names = make([]string, 0, 10) + rebound := make([]byte, 0, len(qs)) + + inName := false + last := len(qs) - 1 + currentVar := 1 + name := make([]byte, 0, 10) + + for i, b := range qs { + // a ':' while we're in a name is an error + if b == ':' { + // if this is the second ':' in a '::' escape sequence, append a ':' + if inName && i > 0 && qs[i-1] == ':' { + rebound = append(rebound, ':') + inName = false + continue + } else if inName { + err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) + return query, names, err + } + inName = true + name = []byte{} + // if we're in a name, and this is an allowed character, continue + } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_') && i != last { + // append the byte to the name if we are in a name and not on the last byte + name = append(name, b) + // if we're in a name and it's not an allowed character, the name is done + } else if inName { + inName = false + // if this is the final byte of the string and it is part of the name, then + // make sure to add it to the name + if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { + name = append(name, b) + } + // add the string representation to the names list + names = append(names, string(name)) + // add a proper bindvar for the bindType + switch bindType { + // oracle only supports named type bind vars even for positional + case NAMED: + rebound = append(rebound, ':') + rebound = append(rebound, name...) + case QUESTION, UNKNOWN: + rebound = append(rebound, '?') + case DOLLAR: + rebound = append(rebound, '$') + for _, b := range strconv.Itoa(currentVar) { + rebound = append(rebound, byte(b)) + } + currentVar++ + } + // add this byte to string unless it was not part of the name + if i != last { + rebound = append(rebound, b) + } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { + rebound = append(rebound, b) + } + } else { + // this is a normal byte and should just go onto the rebound query + rebound = append(rebound, b) + } + } + + return string(rebound), names, err +} + +// BindNamed binds a struct or a map to a query with named parameters. +// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. +func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(bindType, query, arg, mapper()) +} + +// Named takes a query using named parameters and an argument and +// returns a new query with a list of args that can be executed by +// a database. The return value uses the `?` bindvar. +func Named(query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(QUESTION, query, arg, mapper()) +} + +func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { + if maparg, ok := arg.(map[string]interface{}); ok { + return bindMap(bindType, query, maparg) + } + return bindStruct(bindType, query, arg, m) +} + +// NamedQuery binds a named query and then runs Query on the result using the +// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with +// map[string]interface{} types. +func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.Queryx(q, args...) +} + +// NamedExec uses BindStruct to get a query executable by the driver and +// then runs Exec on the result. Returns an error from the binding +// or the query excution itself. +func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.Exec(q, args...) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go new file mode 100644 index 0000000..5728011 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/reflectx/reflect.go @@ -0,0 +1,371 @@ +// Package reflectx implements extensions to the standard reflect lib suitable +// for implementing marshaling and unmarshaling packages. The main Mapper type +// allows for Go-compatible named atribute access, including accessing embedded +// struct attributes and the ability to use functions and struct tags to +// customize field names. +// +package reflectx + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "sync" +) + +// A FieldInfo is a collection of metadata about a struct field. +type FieldInfo struct { + Index []int + Path string + Field reflect.StructField + Zero reflect.Value + Name string + Options map[string]string + Embedded bool + Children []*FieldInfo + Parent *FieldInfo +} + +// A StructMap is an index of field metadata for a struct. +type StructMap struct { + Tree *FieldInfo + Index []*FieldInfo + Paths map[string]*FieldInfo + Names map[string]*FieldInfo +} + +// GetByPath returns a *FieldInfo for a given string path. +func (f StructMap) GetByPath(path string) *FieldInfo { + return f.Paths[path] +} + +// GetByTraversal returns a *FieldInfo for a given integer path. It is +// analagous to reflect.FieldByIndex. +func (f StructMap) GetByTraversal(index []int) *FieldInfo { + if len(index) == 0 { + return nil + } + + tree := f.Tree + for _, i := range index { + if i >= len(tree.Children) || tree.Children[i] == nil { + return nil + } + tree = tree.Children[i] + } + return tree +} + +// Mapper is a general purpose mapper of names to struct fields. A Mapper +// behaves like most marshallers, optionally obeying a field tag for name +// mapping and a function to provide a basic mapping of fields to names. +type Mapper struct { + cache map[reflect.Type]*StructMap + tagName string + tagMapFunc func(string) string + mapFunc func(string) string + mutex sync.Mutex +} + +// NewMapper returns a new mapper which optionally obeys the field tag given +// by tagName. If tagName is the empty string, it is ignored. +func NewMapper(tagName string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + } +} + +// NewMapperTagFunc returns a new mapper which contains a mapper for field names +// AND a mapper for tag values. This is useful for tags like json which can +// have values like "name,omitempty". +func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: mapFunc, + tagMapFunc: tagMapFunc, + } +} + +// NewMapperFunc returns a new mapper which optionally obeys a field tag and +// a struct field name mapper func given by f. Tags will take precedence, but +// for any other field, the mapped name will be f(field.Name) +func NewMapperFunc(tagName string, f func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: f, + } +} + +// TypeMap returns a mapping of field strings to int slices representing +// the traversal down the struct to reach the field. +func (m *Mapper) TypeMap(t reflect.Type) *StructMap { + m.mutex.Lock() + mapping, ok := m.cache[t] + if !ok { + mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) + m.cache[t] = mapping + } + m.mutex.Unlock() + return mapping +} + +// FieldMap returns the mapper's mapping of field names to reflect values. Panics +// if v's Kind is not Struct, or v is not Indirectable to a struct kind. +func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + r := map[string]reflect.Value{} + tm := m.TypeMap(v.Type()) + for tagName, fi := range tm.Names { + r[tagName] = FieldByIndexes(v, fi.Index) + } + return r +} + +// FieldByName returns a field by the its mapped name as a reflect.Value. +// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. +// Returns zero Value if the name is not found. +func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + fi, ok := tm.Names[name] + if !ok { + return v + } + return FieldByIndexes(v, fi.Index) +} + +// FieldsByName returns a slice of values corresponding to the slice of names +// for the value. Panics if v's Kind is not Struct or v is not Indirectable +// to a struct Kind. Returns zero Value for each name not found. +func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + vals := make([]reflect.Value, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + vals = append(vals, *new(reflect.Value)) + } else { + vals = append(vals, FieldByIndexes(v, fi.Index)) + } + } + return vals +} + +// TraversalsByName returns a slice of int slices which represent the struct +// traversals for each mapped name. Panics if t is not a struct or Indirectable +// to a struct. Returns empty int slice for each name not found. +func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { + t = Deref(t) + mustBe(t, reflect.Struct) + tm := m.TypeMap(t) + + r := make([][]int, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + r = append(r, []int{}) + } else { + r = append(r, fi.Index) + } + } + return r +} + +// FieldByIndexes returns a value for a particular struct traversal. +func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + // if this is a pointer, it's possible it is nil + if v.Kind() == reflect.Ptr && v.IsNil() { + alloc := reflect.New(Deref(v.Type())) + v.Set(alloc) + } + if v.Kind() == reflect.Map && v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + } + return v +} + +// FieldByIndexesReadOnly returns a value for a particular struct traversal, +// but is not concerned with allocating nil pointers because the value is +// going to be used for reading and not setting. +func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + } + return v +} + +// Deref is Indirect for reflect.Types +func Deref(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +// -- helpers & utilities -- + +type kinder interface { + Kind() reflect.Kind +} + +// mustBe checks a value against a kind, panicing with a reflect.ValueError +// if the kind isn't that which is required. +func mustBe(v kinder, expected reflect.Kind) { + k := v.Kind() + if k != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: k}) + } +} + +// methodName is returns the caller of the function calling methodName +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +type typeQueue struct { + t reflect.Type + fi *FieldInfo + pp string // Parent path +} + +// A copying append that creates a new slice each time. +func apnd(is []int, i int) []int { + x := make([]int, len(is)+1) + for p, n := range is { + x[p] = n + } + x[len(x)-1] = i + return x +} + +// getMapping returns a mapping for the t type, using the tagName, mapFunc and +// tagMapFunc to determine the canonical names of fields. +func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap { + m := []*FieldInfo{} + + root := &FieldInfo{} + queue := []typeQueue{} + queue = append(queue, typeQueue{Deref(t), root, ""}) + + for len(queue) != 0 { + // pop the first item off of the queue + tq := queue[0] + queue = queue[1:] + nChildren := 0 + if tq.t.Kind() == reflect.Struct { + nChildren = tq.t.NumField() + } + tq.fi.Children = make([]*FieldInfo, nChildren) + + // iterate through all of its fields + for fieldPos := 0; fieldPos < nChildren; fieldPos++ { + f := tq.t.Field(fieldPos) + + fi := FieldInfo{} + fi.Field = f + fi.Zero = reflect.New(f.Type).Elem() + fi.Options = map[string]string{} + + var tag, name string + if tagName != "" && strings.Contains(string(f.Tag), tagName+":") { + tag = f.Tag.Get(tagName) + name = tag + } else { + if mapFunc != nil { + name = mapFunc(f.Name) + } + } + + parts := strings.Split(name, ",") + if len(parts) > 1 { + name = parts[0] + for _, opt := range parts[1:] { + kv := strings.Split(opt, "=") + if len(kv) > 1 { + fi.Options[kv[0]] = kv[1] + } else { + fi.Options[kv[0]] = "" + } + } + } + + if tagMapFunc != nil { + tag = tagMapFunc(tag) + } + + fi.Name = name + + if tq.pp == "" || (tq.pp == "" && tag == "") { + fi.Path = fi.Name + } else { + fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name) + } + + // if the name is "-", disabled via a tag, skip it + if name == "-" { + continue + } + + // skip unexported fields + if len(f.PkgPath) != 0 { + continue + } + + // bfs search of anonymous embedded structs + if f.Anonymous { + pp := tq.pp + if tag != "" { + pp = fi.Path + } + + fi.Embedded = true + fi.Index = apnd(tq.fi.Index, fieldPos) + nChildren := 0 + ft := Deref(f.Type) + if ft.Kind() == reflect.Struct { + nChildren = ft.NumField() + } + fi.Children = make([]*FieldInfo, nChildren) + queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) + } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) + queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) + } + + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Parent = tq.fi + tq.fi.Children[fieldPos] = &fi + m = append(m, &fi) + } + } + + flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} + for _, fi := range flds.Index { + flds.Paths[fi.Path] = fi + if fi.Name != "" && !fi.Embedded { + flds.Names[fi.Path] = fi + } + } + + return flds +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/sqlx.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/sqlx.go new file mode 100644 index 0000000..b1ba4cf --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/jmoiron/sqlx/sqlx.go @@ -0,0 +1,992 @@ +package sqlx + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + + "io/ioutil" + "path/filepath" + "reflect" + "strings" + + "github.com/jmoiron/sqlx/reflectx" +) + +// Although the NameMapper is convenient, in practice it should not +// be relied on except for application code. If you are writing a library +// that uses sqlx, you should be aware that the name mappings you expect +// can be overridded by your user's application. + +// NameMapper is used to map column names to struct field names. By default, +// it uses strings.ToLower to lowercase struct field names. It can be set +// to whatever you want, but it is encouraged to be set before sqlx is used +// as name-to-field mappings are cached after first use on a type. +var NameMapper = strings.ToLower +var origMapper = reflect.ValueOf(NameMapper) + +// Rather than creating on init, this is created when necessary so that +// importers have time to customize the NameMapper. +var mpr *reflectx.Mapper + +// mapper returns a valid mapper using the configured NameMapper func. +func mapper() *reflectx.Mapper { + if mpr == nil { + mpr = reflectx.NewMapperFunc("db", NameMapper) + } else if origMapper != reflect.ValueOf(NameMapper) { + // if NameMapper has changed, create a new mapper + mpr = reflectx.NewMapperFunc("db", NameMapper) + origMapper = reflect.ValueOf(NameMapper) + } + return mpr +} + +// isScannable takes the reflect.Type and the actual dest value and returns +// whether or not it's Scannable. Something is scannable if: +// * it is not a struct +// * it implements sql.Scanner +// * it has no exported fields +func isScannable(t reflect.Type) bool { + if reflect.PtrTo(t).Implements(_scannerInterface) { + return true + } + if t.Kind() != reflect.Struct { + return true + } + + // it's not important that we use the right mapper for this particular object, + // we're only concerned on how many exported fields this struct has + m := mapper() + if len(m.TypeMap(t).Index) == 0 { + return true + } + return false +} + +// ColScanner is an interface used by MapScan and SliceScan +type ColScanner interface { + Columns() ([]string, error) + Scan(dest ...interface{}) error + Err() error +} + +// Queryer is an interface used by Get and Select +type Queryer interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + Queryx(query string, args ...interface{}) (*Rows, error) + QueryRowx(query string, args ...interface{}) *Row +} + +// Execer is an interface used by MustExec and LoadFile +type Execer interface { + Exec(query string, args ...interface{}) (sql.Result, error) +} + +// Binder is an interface for something which can bind queries (Tx, DB) +type binder interface { + DriverName() string + Rebind(string) string + BindNamed(string, interface{}) (string, []interface{}, error) +} + +// Ext is a union interface which can bind, query, and exec, used by +// NamedQuery and NamedExec. +type Ext interface { + binder + Queryer + Execer +} + +// Preparer is an interface used by Preparex. +type Preparer interface { + Prepare(query string) (*sql.Stmt, error) +} + +// determine if any of our extensions are unsafe +func isUnsafe(i interface{}) bool { + switch v := i.(type) { + case Row: + return v.unsafe + case *Row: + return v.unsafe + case Rows: + return v.unsafe + case *Rows: + return v.unsafe + case NamedStmt: + return v.Stmt.unsafe + case *NamedStmt: + return v.Stmt.unsafe + case Stmt: + return v.unsafe + case *Stmt: + return v.unsafe + case qStmt: + return v.unsafe + case *qStmt: + return v.unsafe + case DB: + return v.unsafe + case *DB: + return v.unsafe + case Tx: + return v.unsafe + case *Tx: + return v.unsafe + case sql.Rows, *sql.Rows: + return false + default: + return false + } +} + +func mapperFor(i interface{}) *reflectx.Mapper { + switch i.(type) { + case DB: + return i.(DB).Mapper + case *DB: + return i.(*DB).Mapper + case Tx: + return i.(Tx).Mapper + case *Tx: + return i.(*Tx).Mapper + default: + return mapper() + } +} + +var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// Row is a reimplementation of sql.Row in order to gain access to the underlying +// sql.Rows.Columns() data, necessary for StructScan. +type Row struct { + err error + unsafe bool + rows *sql.Rows + Mapper *reflectx.Mapper +} + +// Scan is a fixed implementation of sql.Row.Scan, which does not discard the +// underlying error from the internal rows object if it exists. +func (r *Row) Scan(dest ...interface{}) error { + if r.err != nil { + return r.err + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + defer r.rows.Close() + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !r.rows.Next() { + if err := r.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := r.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + if err := r.rows.Close(); err != nil { + return err + } + return nil +} + +// Columns returns the underlying sql.Rows.Columns(), or the deferred error usually +// returned by Row.Scan() +func (r *Row) Columns() ([]string, error) { + if r.err != nil { + return []string{}, r.err + } + return r.rows.Columns() +} + +// Err returns the error encountered while scanning. +func (r *Row) Err() error { + return r.err +} + +// DB is a wrapper around sql.DB which keeps track of the driverName upon Open, +// used mostly to automatically bind named queries using the right bindvars. +type DB struct { + *sql.DB + driverName string + unsafe bool + Mapper *reflectx.Mapper +} + +// NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB. The +// driverName of the original database is required for named query support. +func NewDb(db *sql.DB, driverName string) *DB { + return &DB{DB: db, driverName: driverName, Mapper: mapper()} +} + +// DriverName returns the driverName passed to the Open function for this DB. +func (db *DB) DriverName() string { + return db.driverName +} + +// Open is the same as sql.Open, but returns an *sqlx.DB instead. +func Open(driverName, dataSourceName string) (*DB, error) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err +} + +// MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error. +func MustOpen(driverName, dataSourceName string) *DB { + db, err := Open(driverName, dataSourceName) + if err != nil { + panic(err) + } + return db +} + +// MapperFunc sets a new mapper for this db using the default sqlx struct tag +// and the provided mapper function. +func (db *DB) MapperFunc(mf func(string) string) { + db.Mapper = reflectx.NewMapperFunc("db", mf) +} + +// Rebind transforms a query from QUESTION to the DB driver's bindvar type. +func (db *DB) Rebind(query string) string { + return Rebind(BindType(db.driverName), query) +} + +// Unsafe returns a version of DB which will silently succeed to scan when +// columns in the SQL result have no fields in the destination struct. +// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its +// safety behavior. +func (db *DB) Unsafe() *DB { + return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper} +} + +// BindNamed binds a query using the DB driver's bindvar type. +func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper) +} + +// NamedQuery using this DB. +func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { + return NamedQuery(db, query, arg) +} + +// NamedExec using this DB. +func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { + return NamedExec(db, query, arg) +} + +// Select using this DB. +func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { + return Select(db, dest, query, args...) +} + +// Get using this DB. +func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { + return Get(db, dest, query, args...) +} + +// MustBegin starts a transaction, and panics on error. Returns an *sqlx.Tx instead +// of an *sql.Tx. +func (db *DB) MustBegin() *Tx { + tx, err := db.Beginx() + if err != nil { + panic(err) + } + return tx +} + +// Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx. +func (db *DB) Beginx() (*Tx, error) { + tx, err := db.DB.Begin() + if err != nil { + return nil, err + } + return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// Queryx queries the database and returns an *sqlx.Rows. +func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { + r, err := db.DB.Query(query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// QueryRowx queries the database and returns an *sqlx.Row. +func (db *DB) QueryRowx(query string, args ...interface{}) *Row { + rows, err := db.DB.Query(query, args...) + return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} +} + +// MustExec (panic) runs MustExec using this database. +func (db *DB) MustExec(query string, args ...interface{}) sql.Result { + return MustExec(db, query, args...) +} + +// Preparex returns an sqlx.Stmt instead of a sql.Stmt +func (db *DB) Preparex(query string) (*Stmt, error) { + return Preparex(db, query) +} + +// PrepareNamed returns an sqlx.NamedStmt +func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { + return prepareNamed(db, query) +} + +// Tx is an sqlx wrapper around sql.Tx with extra functionality +type Tx struct { + *sql.Tx + driverName string + unsafe bool + Mapper *reflectx.Mapper +} + +// DriverName returns the driverName used by the DB which began this transaction. +func (tx *Tx) DriverName() string { + return tx.driverName +} + +// Rebind a query within a transaction's bindvar type. +func (tx *Tx) Rebind(query string) string { + return Rebind(BindType(tx.driverName), query) +} + +// Unsafe returns a version of Tx which will silently succeed to scan when +// columns in the SQL result have no fields in the destination struct. +func (tx *Tx) Unsafe() *Tx { + return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper} +} + +// BindNamed binds a query within a transaction's bindvar type. +func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) { + return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper) +} + +// NamedQuery within a transaction. +func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { + return NamedQuery(tx, query, arg) +} + +// NamedExec a named query within a transaction. +func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { + return NamedExec(tx, query, arg) +} + +// Select within a transaction. +func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { + return Select(tx, dest, query, args...) +} + +// Queryx within a transaction. +func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { + r, err := tx.Tx.Query(query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err +} + +// QueryRowx within a transaction. +func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { + rows, err := tx.Tx.Query(query, args...) + return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} +} + +// Get within a transaction. +func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { + return Get(tx, dest, query, args...) +} + +// MustExec runs MustExec within a transaction. +func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { + return MustExec(tx, query, args...) +} + +// Preparex a statement within a transaction. +func (tx *Tx) Preparex(query string) (*Stmt, error) { + return Preparex(tx, query) +} + +// Stmtx returns a version of the prepared statement which runs within a transaction. Provided +// stmt can be either *sql.Stmt or *sqlx.Stmt. +func (tx *Tx) Stmtx(stmt interface{}) *Stmt { + var s *sql.Stmt + switch v := stmt.(type) { + case Stmt: + s = v.Stmt + case *Stmt: + s = v.Stmt + case sql.Stmt: + s = &v + case *sql.Stmt: + s = v + default: + panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) + } + return &Stmt{Stmt: tx.Stmt(s), Mapper: tx.Mapper} +} + +// NamedStmt returns a version of the prepared statement which runs within a transaction. +func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt { + return &NamedStmt{ + QueryString: stmt.QueryString, + Params: stmt.Params, + Stmt: tx.Stmtx(stmt.Stmt), + } +} + +// PrepareNamed returns an sqlx.NamedStmt +func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) { + return prepareNamed(tx, query) +} + +// Stmt is an sqlx wrapper around sql.Stmt with extra functionality +type Stmt struct { + *sql.Stmt + unsafe bool + Mapper *reflectx.Mapper +} + +// Unsafe returns a version of Stmt which will silently succeed to scan when +// columns in the SQL result have no fields in the destination struct. +func (s *Stmt) Unsafe() *Stmt { + return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper} +} + +// Select using the prepared statement. +func (s *Stmt) Select(dest interface{}, args ...interface{}) error { + return Select(&qStmt{s}, dest, "", args...) +} + +// Get using the prepared statement. +func (s *Stmt) Get(dest interface{}, args ...interface{}) error { + return Get(&qStmt{s}, dest, "", args...) +} + +// MustExec (panic) using this statement. Note that the query portion of the error +// output will be blank, as Stmt does not expose its query. +func (s *Stmt) MustExec(args ...interface{}) sql.Result { + return MustExec(&qStmt{s}, "", args...) +} + +// QueryRowx using this statement. +func (s *Stmt) QueryRowx(args ...interface{}) *Row { + qs := &qStmt{s} + return qs.QueryRowx("", args...) +} + +// Queryx using this statement. +func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { + qs := &qStmt{s} + return qs.Queryx("", args...) +} + +// qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by +// implementing those interfaces and ignoring the `query` argument. +type qStmt struct{ *Stmt } + +func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) { + return q.Stmt.Query(args...) +} + +func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { + r, err := q.Stmt.Query(args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err +} + +func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row { + rows, err := q.Stmt.Query(args...) + return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} +} + +func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) { + return q.Stmt.Exec(args...) +} + +// Rows is a wrapper around sql.Rows which caches costly reflect operations +// during a looped StructScan +type Rows struct { + *sql.Rows + unsafe bool + Mapper *reflectx.Mapper + // these fields cache memory use for a rows during iteration w/ structScan + started bool + fields [][]int + values []interface{} +} + +// SliceScan using this Rows. +func (r *Rows) SliceScan() ([]interface{}, error) { + return SliceScan(r) +} + +// MapScan using this Rows. +func (r *Rows) MapScan(dest map[string]interface{}) error { + return MapScan(r, dest) +} + +// StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct. +// Use this and iterate over Rows manually when the memory load of Select() might be +// prohibitive. *Rows.StructScan caches the reflect work of matching up column +// positions to fields to avoid that overhead per scan, which means it is not safe +// to run StructScan on the same Rows instance with different struct types. +func (r *Rows) StructScan(dest interface{}) error { + v := reflect.ValueOf(dest) + + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + + v = reflect.Indirect(v) + + if !r.started { + columns, err := r.Columns() + if err != nil { + return err + } + m := r.Mapper + + r.fields = m.TraversalsByName(v.Type(), columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(r.fields); err != nil && !r.unsafe { + return fmt.Errorf("missing destination name %s", columns[f]) + } + r.values = make([]interface{}, len(columns)) + r.started = true + } + + err := fieldsByTraversal(v, r.fields, r.values, true) + if err != nil { + return err + } + // scan into the struct field pointers and append to our results + err = r.Scan(r.values...) + if err != nil { + return err + } + return r.Err() +} + +// Connect to a database and verify with a ping. +func Connect(driverName, dataSourceName string) (*DB, error) { + db, err := Open(driverName, dataSourceName) + if err != nil { + return db, err + } + err = db.Ping() + return db, err +} + +// MustConnect connects to a database and panics on error. +func MustConnect(driverName, dataSourceName string) *DB { + db, err := Connect(driverName, dataSourceName) + if err != nil { + panic(err) + } + return db +} + +// Preparex prepares a statement. +func Preparex(p Preparer, query string) (*Stmt, error) { + s, err := p.Prepare(query) + if err != nil { + return nil, err + } + return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err +} + +// Select executes a query using the provided Queryer, and StructScans each row +// into dest, which must be a slice. If the slice elements are scannable, then +// the result set must have only one column. Otherwise, StructScan is used. +// The *sql.Rows are closed automatically. +func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { + rows, err := q.Queryx(query, args...) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// Get does a QueryRow using the provided Queryer, and scans the resulting row +// to dest. If dest is scannable, the result must only have one column. Otherwise, +// StructScan is used. Get will return sql.ErrNoRows like row.Scan would. +func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { + r := q.QueryRowx(query, args...) + return r.scanAny(dest, false) +} + +// LoadFile exec's every statement in a file (as a single call to Exec). +// LoadFile may return a nil *sql.Result if errors are encountered locating or +// reading the file at path. LoadFile reads the entire file into memory, so it +// is not suitable for loading large data dumps, but can be useful for initializing +// schemas or loading indexes. +// +// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 +// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting +// this by requiring something with DriverName() and then attempting to split the +// queries will be difficult to get right, and its current driver-specific behavior +// is deemed at least not complex in its incorrectness. +func LoadFile(e Execer, path string) (*sql.Result, error) { + realpath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + contents, err := ioutil.ReadFile(realpath) + if err != nil { + return nil, err + } + res, err := e.Exec(string(contents)) + return &res, err +} + +// MustExec execs the query using e and panics if there was an error. +func MustExec(e Execer, query string, args ...interface{}) sql.Result { + res, err := e.Exec(query, args...) + if err != nil { + panic(err) + } + return res +} + +// SliceScan using this Rows. +func (r *Row) SliceScan() ([]interface{}, error) { + return SliceScan(r) +} + +// MapScan using this Rows. +func (r *Row) MapScan(dest map[string]interface{}) error { + return MapScan(r, dest) +} + +func (r *Row) scanAny(dest interface{}, structOnly bool) error { + if r.err != nil { + return r.err + } + defer r.rows.Close() + + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + if v.IsNil() { + return errors.New("nil pointer passed to StructScan destination") + } + + base := reflectx.Deref(v.Type()) + scannable := isScannable(base) + + if structOnly && scannable { + return structOnlyError(base) + } + + columns, err := r.Columns() + if err != nil { + return err + } + + if scannable && len(columns) > 1 { + return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns)) + } + + if scannable { + return r.Scan(dest) + } + + m := r.Mapper + + fields := m.TraversalsByName(v.Type(), columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(fields); err != nil && !r.unsafe { + return fmt.Errorf("missing destination name %s", columns[f]) + } + values := make([]interface{}, len(columns)) + + err = fieldsByTraversal(v, fields, values, true) + if err != nil { + return err + } + // scan into the struct field pointers and append to our results + return r.Scan(values...) +} + +// StructScan a single Row into dest. +func (r *Row) StructScan(dest interface{}) error { + return r.scanAny(dest, true) +} + +// SliceScan a row, returning a []interface{} with values similar to MapScan. +// This function is primarly intended for use where the number of columns +// is not known. Because you can pass an []interface{} directly to Scan, +// it's recommended that you do that as it will not have to allocate new +// slices per row. +func SliceScan(r ColScanner) ([]interface{}, error) { + // ignore r.started, since we needn't use reflect for anything. + columns, err := r.Columns() + if err != nil { + return []interface{}{}, err + } + + values := make([]interface{}, len(columns)) + for i := range values { + values[i] = new(interface{}) + } + + err = r.Scan(values...) + + if err != nil { + return values, err + } + + for i := range columns { + values[i] = *(values[i].(*interface{})) + } + + return values, r.Err() +} + +// MapScan scans a single Row into the dest map[string]interface{}. +// Use this to get results for SQL that might not be under your control +// (for instance, if you're building an interface for an SQL server that +// executes SQL from input). Please do not use this as a primary interface! +// This will modify the map sent to it in place, so reuse the same map with +// care. Columns which occur more than once in the result will overwrite +// eachother! +func MapScan(r ColScanner, dest map[string]interface{}) error { + // ignore r.started, since we needn't use reflect for anything. + columns, err := r.Columns() + if err != nil { + return err + } + + values := make([]interface{}, len(columns)) + for i := range values { + values[i] = new(interface{}) + } + + err = r.Scan(values...) + if err != nil { + return err + } + + for i, column := range columns { + dest[column] = *(values[i].(*interface{})) + } + + return r.Err() +} + +type rowsi interface { + Close() error + Columns() ([]string, error) + Err() error + Next() bool + Scan(...interface{}) error +} + +// structOnlyError returns an error appropriate for type when a non-scannable +// struct is expected but something else is given +func structOnlyError(t reflect.Type) error { + isStruct := t.Kind() == reflect.Struct + isScanner := reflect.PtrTo(t).Implements(_scannerInterface) + if !isStruct { + return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) + } + if isScanner { + return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name()) + } + return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) +} + +// scanAll scans all rows into a destination, which must be a slice of any +// type. If the destination slice type is a Struct, then StructScan will be +// used on each row. If the destination is some other kind of base type, then +// each row must only have one column which can scan into that type. This +// allows you to do something like: +// +// rows, _ := db.Query("select id from people;") +// var ids []int +// scanAll(rows, &ids, false) +// +// and ids will be a list of the id results. I realize that this is a desirable +// interface to expose to users, but for now it will only be exposed via changes +// to `Get` and `Select`. The reason that this has been implemented like this is +// this is the only way to not duplicate reflect work in the new API while +// maintaining backwards compatibility. +func scanAll(rows rowsi, dest interface{}, structOnly bool) error { + var v, vp reflect.Value + + value := reflect.ValueOf(dest) + + // json.Unmarshal returns errors for these + if value.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + if value.IsNil() { + return errors.New("nil pointer passed to StructScan destination") + } + direct := reflect.Indirect(value) + + slice, err := baseType(value.Type(), reflect.Slice) + if err != nil { + return err + } + + isPtr := slice.Elem().Kind() == reflect.Ptr + base := reflectx.Deref(slice.Elem()) + scannable := isScannable(base) + + if structOnly && scannable { + return structOnlyError(base) + } + + columns, err := rows.Columns() + if err != nil { + return err + } + + // if it's a base type make sure it only has 1 column; if not return an error + if scannable && len(columns) > 1 { + return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns)) + } + + if !scannable { + var values []interface{} + var m *reflectx.Mapper + + switch rows.(type) { + case *Rows: + m = rows.(*Rows).Mapper + default: + m = mapper() + } + + fields := m.TraversalsByName(base, columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(fields); err != nil && !isUnsafe(rows) { + return fmt.Errorf("missing destination name %s", columns[f]) + } + values = make([]interface{}, len(columns)) + + for rows.Next() { + // create a new struct type (which returns PtrTo) and indirect it + vp = reflect.New(base) + v = reflect.Indirect(vp) + + err = fieldsByTraversal(v, fields, values, true) + + // scan into the struct field pointers and append to our results + err = rows.Scan(values...) + if err != nil { + return err + } + + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, v)) + } + } + } else { + for rows.Next() { + vp = reflect.New(base) + err = rows.Scan(vp.Interface()) + // append + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, reflect.Indirect(vp))) + } + } + } + + return rows.Err() +} + +// FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately +// it doesn't really feel like it's named properly. There is an incongruency +// between this and the way that StructScan (which might better be ScanStruct +// anyway) works on a rows object. + +// StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. +// StructScan will scan in the entire rows result, so if you need do not want to +// allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. +// If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. +func StructScan(rows rowsi, dest interface{}) error { + return scanAll(rows, dest, true) + +} + +// reflect helpers + +func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { + t = reflectx.Deref(t) + if t.Kind() != expected { + return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) + } + return t, nil +} + +// fieldsByName fills a values interface with fields from the passed value based +// on the traversals in int. If ptrs is true, return addresses instead of values. +// We write this instead of using FieldsByName to save allocations and map lookups +// when iterating over many rows. Empty traversals will get an interface pointer. +// Because of the necessity of requesting ptrs or values, it's considered a bit too +// specialized for inclusion in reflectx itself. +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return errors.New("argument not a struct") + } + + for i, traversal := range traversals { + if len(traversal) == 0 { + values[i] = new(interface{}) + continue + } + f := reflectx.FieldByIndexes(v, traversal) + if ptrs { + values[i] = f.Addr().Interface() + } else { + values[i] = f.Interface() + } + } + return nil +} + +func missingFields(transversals [][]int) (field int, err error) { + for i, t := range transversals { + if len(t) == 0 { + return i, errors.New("missing field") + } + } + return 0, nil +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/LICENSE new file mode 100644 index 0000000..ade9307 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/LICENSE @@ -0,0 +1,191 @@ +All files in this repository are licensed as follows. If you contribute +to this repository, it is assumed that you license your contribution +under the same license unless you state otherwise. + +All files Copyright (C) 2015 Canonical Ltd. unless otherwise specified in the file. + +This software is licensed under the LGPLv3, included below. + +As a special exception to the GNU Lesser General Public License version 3 +("LGPL3"), the copyright holders of this Library give you permission to +convey to a third party a Combined Work that links statically or dynamically +to this Library without providing any Minimal Corresponding Source or +Minimal Application Code as set out in 4d or providing the installation +information set out in section 4e, provided that you comply with the other +provisions of LGPL3 and provided that you meet, for the Application the +terms and conditions of the license(s) which apply to the Application. + +Except as stated in this special exception, the provisions of LGPL3 will +continue to comply in full to this Library. If you modify this Library, you +may apply this exception to your version of this Library, but you are not +obliged to do so. If you do not wish to do so, delete this exception +statement from your version. This exception does not (and cannot) modify any +license terms which apply to the Application, with which you must still +comply. + + + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/doc.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/doc.go new file mode 100644 index 0000000..35b119a --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/doc.go @@ -0,0 +1,81 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +/* +[godoc-link-here] + +The juju/errors provides an easy way to annotate errors without losing the +orginal error context. + +The exported `New` and `Errorf` functions are designed to replace the +`errors.New` and `fmt.Errorf` functions respectively. The same underlying +error is there, but the package also records the location at which the error +was created. + +A primary use case for this library is to add extra context any time an +error is returned from a function. + + if err := SomeFunc(); err != nil { + return err + } + +This instead becomes: + + if err := SomeFunc(); err != nil { + return errors.Trace(err) + } + +which just records the file and line number of the Trace call, or + + if err := SomeFunc(); err != nil { + return errors.Annotate(err, "more context") + } + +which also adds an annotation to the error. + +When you want to check to see if an error is of a particular type, a helper +function is normally exported by the package that returned the error, like the +`os` package does. The underlying cause of the error is available using the +`Cause` function. + + os.IsNotExist(errors.Cause(err)) + +The result of the `Error()` call on an annotated error is the annotations joined +with colons, then the result of the `Error()` method for the underlying error +that was the cause. + + err := errors.Errorf("original") + err = errors.Annotatef(err, "context") + err = errors.Annotatef(err, "more context") + err.Error() -> "more context: context: original" + +Obviously recording the file, line and functions is not very useful if you +cannot get them back out again. + + errors.ErrorStack(err) + +will return something like: + + first error + github.com/juju/errors/annotation_test.go:193: + github.com/juju/errors/annotation_test.go:194: annotation + github.com/juju/errors/annotation_test.go:195: + github.com/juju/errors/annotation_test.go:196: more context + github.com/juju/errors/annotation_test.go:197: + +The first error was generated by an external system, so there was no location +associated. The second, fourth, and last lines were generated with Trace calls, +and the other two through Annotate. + +Sometimes when responding to an error you want to return a more specific error +for the situation. + + if err := FindField(field); err != nil { + return errors.Wrap(err, errors.NotFoundf(field)) + } + +This returns an error where the complete error stack is still available, and +`errors.Cause()` will return the `NotFound` error. + +*/ +package errors diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/error.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/error.go new file mode 100644 index 0000000..8c51c45 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/error.go @@ -0,0 +1,145 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "fmt" + "reflect" + "runtime" +) + +// Err holds a description of an error along with information about +// where the error was created. +// +// It may be embedded in custom error types to add extra information that +// this errors package can understand. +type Err struct { + // message holds an annotation of the error. + message string + + // cause holds the cause of the error as returned + // by the Cause method. + cause error + + // previous holds the previous error in the error stack, if any. + previous error + + // file and line hold the source code location where the error was + // created. + file string + line int +} + +// NewErr is used to return an Err for the purpose of embedding in other +// structures. The location is not specified, and needs to be set with a call +// to SetLocation. +// +// For example: +// type FooError struct { +// errors.Err +// code int +// } +// +// func NewFooError(code int) error { +// err := &FooError{errors.NewErr("foo"), code} +// err.SetLocation(1) +// return err +// } +func NewErr(format string, args ...interface{}) Err { + return Err{ + message: fmt.Sprintf(format, args...), + } +} + +// NewErrWithCause is used to return an Err with case by other error for the purpose of embedding in other +// structures. The location is not specified, and needs to be set with a call +// to SetLocation. +// +// For example: +// type FooError struct { +// errors.Err +// code int +// } +// +// func (e *FooError) Annotate(format string, args ...interface{}) error { +// err := &FooError{errors.NewErrWithCause(e.Err, format, args...), e.code} +// err.SetLocation(1) +// return err +// }) +func NewErrWithCause(other error, format string, args ...interface{}) Err { + return Err{ + message: fmt.Sprintf(format, args...), + cause: Cause(other), + previous: other, + } +} + +// Location is the file and line of where the error was most recently +// created or annotated. +func (e *Err) Location() (filename string, line int) { + return e.file, e.line +} + +// Underlying returns the previous error in the error stack, if any. A client +// should not ever really call this method. It is used to build the error +// stack and should not be introspected by client calls. Or more +// specifically, clients should not depend on anything but the `Cause` of an +// error. +func (e *Err) Underlying() error { + return e.previous +} + +// The Cause of an error is the most recent error in the error stack that +// meets one of these criteria: the original error that was raised; the new +// error that was passed into the Wrap function; the most recently masked +// error; or nil if the error itself is considered the Cause. Normally this +// method is not invoked directly, but instead through the Cause stand alone +// function. +func (e *Err) Cause() error { + return e.cause +} + +// Message returns the message stored with the most recent location. This is +// the empty string if the most recent call was Trace, or the message stored +// with Annotate or Mask. +func (e *Err) Message() string { + return e.message +} + +// Error implements error.Error. +func (e *Err) Error() string { + // We want to walk up the stack of errors showing the annotations + // as long as the cause is the same. + err := e.previous + if !sameError(Cause(err), e.cause) && e.cause != nil { + err = e.cause + } + switch { + case err == nil: + return e.message + case e.message == "": + return err.Error() + } + return fmt.Sprintf("%s: %v", e.message, err) +} + +// SetLocation records the source location of the error at callDepth stack +// frames above the call. +func (e *Err) SetLocation(callDepth int) { + _, file, line, _ := runtime.Caller(callDepth + 1) + e.file = trimGoPath(file) + e.line = line +} + +// StackTrace returns one string for each location recorded in the stack of +// errors. The first value is the originating error, with a line for each +// other annotation or tracing of the error. +func (e *Err) StackTrace() []string { + return errorStack(e) +} + +// Ideally we'd have a way to check identity, but deep equals will do. +func sameError(e1, e2 error) bool { + return reflect.DeepEqual(e1, e2) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/errortypes.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/errortypes.go new file mode 100644 index 0000000..10b3b19 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/errortypes.go @@ -0,0 +1,284 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "fmt" +) + +// wrap is a helper to construct an *wrapper. +func wrap(err error, format, suffix string, args ...interface{}) Err { + newErr := Err{ + message: fmt.Sprintf(format+suffix, args...), + previous: err, + } + newErr.SetLocation(2) + return newErr +} + +// notFound represents an error when something has not been found. +type notFound struct { + Err +} + +// NotFoundf returns an error which satisfies IsNotFound(). +func NotFoundf(format string, args ...interface{}) error { + return ¬Found{wrap(nil, format, " not found", args...)} +} + +// NewNotFound returns an error which wraps err that satisfies +// IsNotFound(). +func NewNotFound(err error, msg string) error { + return ¬Found{wrap(err, msg, "")} +} + +// IsNotFound reports whether err was created with NotFoundf() or +// NewNotFound(). +func IsNotFound(err error) bool { + err = Cause(err) + _, ok := err.(*notFound) + return ok +} + +// userNotFound represents an error when an inexistent user is looked up. +type userNotFound struct { + Err +} + +// UserNotFoundf returns an error which satisfies IsUserNotFound(). +func UserNotFoundf(format string, args ...interface{}) error { + return &userNotFound{wrap(nil, format, " user not found", args...)} +} + +// NewUserNotFound returns an error which wraps err and satisfies +// IsUserNotFound(). +func NewUserNotFound(err error, msg string) error { + return &userNotFound{wrap(err, msg, "")} +} + +// IsUserNotFound reports whether err was created with UserNotFoundf() or +// NewUserNotFound(). +func IsUserNotFound(err error) bool { + err = Cause(err) + _, ok := err.(*userNotFound) + return ok +} + +// unauthorized represents an error when an operation is unauthorized. +type unauthorized struct { + Err +} + +// Unauthorizedf returns an error which satisfies IsUnauthorized(). +func Unauthorizedf(format string, args ...interface{}) error { + return &unauthorized{wrap(nil, format, "", args...)} +} + +// NewUnauthorized returns an error which wraps err and satisfies +// IsUnauthorized(). +func NewUnauthorized(err error, msg string) error { + return &unauthorized{wrap(err, msg, "")} +} + +// IsUnauthorized reports whether err was created with Unauthorizedf() or +// NewUnauthorized(). +func IsUnauthorized(err error) bool { + err = Cause(err) + _, ok := err.(*unauthorized) + return ok +} + +// notImplemented represents an error when something is not +// implemented. +type notImplemented struct { + Err +} + +// NotImplementedf returns an error which satisfies IsNotImplemented(). +func NotImplementedf(format string, args ...interface{}) error { + return ¬Implemented{wrap(nil, format, " not implemented", args...)} +} + +// NewNotImplemented returns an error which wraps err and satisfies +// IsNotImplemented(). +func NewNotImplemented(err error, msg string) error { + return ¬Implemented{wrap(err, msg, "")} +} + +// IsNotImplemented reports whether err was created with +// NotImplementedf() or NewNotImplemented(). +func IsNotImplemented(err error) bool { + err = Cause(err) + _, ok := err.(*notImplemented) + return ok +} + +// alreadyExists represents and error when something already exists. +type alreadyExists struct { + Err +} + +// AlreadyExistsf returns an error which satisfies IsAlreadyExists(). +func AlreadyExistsf(format string, args ...interface{}) error { + return &alreadyExists{wrap(nil, format, " already exists", args...)} +} + +// NewAlreadyExists returns an error which wraps err and satisfies +// IsAlreadyExists(). +func NewAlreadyExists(err error, msg string) error { + return &alreadyExists{wrap(err, msg, "")} +} + +// IsAlreadyExists reports whether the error was created with +// AlreadyExistsf() or NewAlreadyExists(). +func IsAlreadyExists(err error) bool { + err = Cause(err) + _, ok := err.(*alreadyExists) + return ok +} + +// notSupported represents an error when something is not supported. +type notSupported struct { + Err +} + +// NotSupportedf returns an error which satisfies IsNotSupported(). +func NotSupportedf(format string, args ...interface{}) error { + return ¬Supported{wrap(nil, format, " not supported", args...)} +} + +// NewNotSupported returns an error which wraps err and satisfies +// IsNotSupported(). +func NewNotSupported(err error, msg string) error { + return ¬Supported{wrap(err, msg, "")} +} + +// IsNotSupported reports whether the error was created with +// NotSupportedf() or NewNotSupported(). +func IsNotSupported(err error) bool { + err = Cause(err) + _, ok := err.(*notSupported) + return ok +} + +// notValid represents an error when something is not valid. +type notValid struct { + Err +} + +// NotValidf returns an error which satisfies IsNotValid(). +func NotValidf(format string, args ...interface{}) error { + return ¬Valid{wrap(nil, format, " not valid", args...)} +} + +// NewNotValid returns an error which wraps err and satisfies IsNotValid(). +func NewNotValid(err error, msg string) error { + return ¬Valid{wrap(err, msg, "")} +} + +// IsNotValid reports whether the error was created with NotValidf() or +// NewNotValid(). +func IsNotValid(err error) bool { + err = Cause(err) + _, ok := err.(*notValid) + return ok +} + +// notProvisioned represents an error when something is not yet provisioned. +type notProvisioned struct { + Err +} + +// NotProvisionedf returns an error which satisfies IsNotProvisioned(). +func NotProvisionedf(format string, args ...interface{}) error { + return ¬Provisioned{wrap(nil, format, " not provisioned", args...)} +} + +// NewNotProvisioned returns an error which wraps err that satisfies +// IsNotProvisioned(). +func NewNotProvisioned(err error, msg string) error { + return ¬Provisioned{wrap(err, msg, "")} +} + +// IsNotProvisioned reports whether err was created with NotProvisionedf() or +// NewNotProvisioned(). +func IsNotProvisioned(err error) bool { + err = Cause(err) + _, ok := err.(*notProvisioned) + return ok +} + +// notAssigned represents an error when something is not yet assigned to +// something else. +type notAssigned struct { + Err +} + +// NotAssignedf returns an error which satisfies IsNotAssigned(). +func NotAssignedf(format string, args ...interface{}) error { + return ¬Assigned{wrap(nil, format, " not assigned", args...)} +} + +// NewNotAssigned returns an error which wraps err that satisfies +// IsNotAssigned(). +func NewNotAssigned(err error, msg string) error { + return ¬Assigned{wrap(err, msg, "")} +} + +// IsNotAssigned reports whether err was created with NotAssignedf() or +// NewNotAssigned(). +func IsNotAssigned(err error) bool { + err = Cause(err) + _, ok := err.(*notAssigned) + return ok +} + +// badRequest represents an error when a request has bad parameters. +type badRequest struct { + Err +} + +// BadRequestf returns an error which satisfies IsBadRequest(). +func BadRequestf(format string, args ...interface{}) error { + return &badRequest{wrap(nil, format, "", args...)} +} + +// NewBadRequest returns an error which wraps err that satisfies +// IsBadRequest(). +func NewBadRequest(err error, msg string) error { + return &badRequest{wrap(err, msg, "")} +} + +// IsBadRequest reports whether err was created with BadRequestf() or +// NewBadRequest(). +func IsBadRequest(err error) bool { + err = Cause(err) + _, ok := err.(*badRequest) + return ok +} + +// methodNotAllowed represents an error when an HTTP request +// is made with an inappropriate method. +type methodNotAllowed struct { + Err +} + +// MethodNotAllowedf returns an error which satisfies IsMethodNotAllowed(). +func MethodNotAllowedf(format string, args ...interface{}) error { + return &methodNotAllowed{wrap(nil, format, "", args...)} +} + +// NewMethodNotAllowed returns an error which wraps err that satisfies +// IsMethodNotAllowed(). +func NewMethodNotAllowed(err error, msg string) error { + return &methodNotAllowed{wrap(err, msg, "")} +} + +// IsMethodNotAllowed reports whether err was created with MethodNotAllowedf() or +// NewMethodNotAllowed(). +func IsMethodNotAllowed(err error) bool { + err = Cause(err) + _, ok := err.(*methodNotAllowed) + return ok +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/functions.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/functions.go new file mode 100644 index 0000000..994208d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/functions.go @@ -0,0 +1,330 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "fmt" + "strings" +) + +// New is a drop in replacement for the standard libary errors module that records +// the location that the error is created. +// +// For example: +// return errors.New("validation failed") +// +func New(message string) error { + err := &Err{message: message} + err.SetLocation(1) + return err +} + +// Errorf creates a new annotated error and records the location that the +// error is created. This should be a drop in replacement for fmt.Errorf. +// +// For example: +// return errors.Errorf("validation failed: %s", message) +// +func Errorf(format string, args ...interface{}) error { + err := &Err{message: fmt.Sprintf(format, args...)} + err.SetLocation(1) + return err +} + +// Trace adds the location of the Trace call to the stack. The Cause of the +// resulting error is the same as the error parameter. If the other error is +// nil, the result will be nil. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Trace(err) +// } +// +func Trace(other error) error { + if other == nil { + return nil + } + err := &Err{previous: other, cause: Cause(other)} + err.SetLocation(1) + return err +} + +// Annotate is used to add extra context to an existing error. The location of +// the Annotate call is recorded with the annotations. The file, line and +// function are also recorded. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Annotate(err, "failed to frombulate") +// } +// +func Annotate(other error, message string) error { + if other == nil { + return nil + } + err := &Err{ + previous: other, + cause: Cause(other), + message: message, + } + err.SetLocation(1) + return err +} + +// Annotatef is used to add extra context to an existing error. The location of +// the Annotate call is recorded with the annotations. The file, line and +// function are also recorded. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Annotatef(err, "failed to frombulate the %s", arg) +// } +// +func Annotatef(other error, format string, args ...interface{}) error { + if other == nil { + return nil + } + err := &Err{ + previous: other, + cause: Cause(other), + message: fmt.Sprintf(format, args...), + } + err.SetLocation(1) + return err +} + +// DeferredAnnotatef annotates the given error (when it is not nil) with the given +// format string and arguments (like fmt.Sprintf). If *err is nil, DeferredAnnotatef +// does nothing. This method is used in a defer statement in order to annotate any +// resulting error with the same message. +// +// For example: +// +// defer DeferredAnnotatef(&err, "failed to frombulate the %s", arg) +// +func DeferredAnnotatef(err *error, format string, args ...interface{}) { + if *err == nil { + return + } + newErr := &Err{ + message: fmt.Sprintf(format, args...), + cause: Cause(*err), + previous: *err, + } + newErr.SetLocation(1) + *err = newErr +} + +// Wrap changes the Cause of the error. The location of the Wrap call is also +// stored in the error stack. +// +// For example: +// if err := SomeFunc(); err != nil { +// newErr := &packageError{"more context", private_value} +// return errors.Wrap(err, newErr) +// } +// +func Wrap(other, newDescriptive error) error { + err := &Err{ + previous: other, + cause: newDescriptive, + } + err.SetLocation(1) + return err +} + +// Wrapf changes the Cause of the error, and adds an annotation. The location +// of the Wrap call is also stored in the error stack. +// +// For example: +// if err := SomeFunc(); err != nil { +// return errors.Wrapf(err, simpleErrorType, "invalid value %q", value) +// } +// +func Wrapf(other, newDescriptive error, format string, args ...interface{}) error { + err := &Err{ + message: fmt.Sprintf(format, args...), + previous: other, + cause: newDescriptive, + } + err.SetLocation(1) + return err +} + +// Mask masks the given error with the given format string and arguments (like +// fmt.Sprintf), returning a new error that maintains the error stack, but +// hides the underlying error type. The error string still contains the full +// annotations. If you want to hide the annotations, call Wrap. +func Maskf(other error, format string, args ...interface{}) error { + if other == nil { + return nil + } + err := &Err{ + message: fmt.Sprintf(format, args...), + previous: other, + } + err.SetLocation(1) + return err +} + +// Mask hides the underlying error type, and records the location of the masking. +func Mask(other error) error { + if other == nil { + return nil + } + err := &Err{ + previous: other, + } + err.SetLocation(1) + return err +} + +// Cause returns the cause of the given error. This will be either the +// original error, or the result of a Wrap or Mask call. +// +// Cause is the usual way to diagnose errors that may have been wrapped by +// the other errors functions. +func Cause(err error) error { + var diag error + if err, ok := err.(causer); ok { + diag = err.Cause() + } + if diag != nil { + return diag + } + return err +} + +type causer interface { + Cause() error +} + +type wrapper interface { + // Message returns the top level error message, + // not including the message from the Previous + // error. + Message() string + + // Underlying returns the Previous error, or nil + // if there is none. + Underlying() error +} + +type locationer interface { + Location() (string, int) +} + +var ( + _ wrapper = (*Err)(nil) + _ locationer = (*Err)(nil) + _ causer = (*Err)(nil) +) + +// Details returns information about the stack of errors wrapped by err, in +// the format: +// +// [{filename:99: error one} {otherfile:55: cause of error one}] +// +// This is a terse alternative to ErrorStack as it returns a single line. +func Details(err error) string { + if err == nil { + return "[]" + } + var s []byte + s = append(s, '[') + for { + s = append(s, '{') + if err, ok := err.(locationer); ok { + file, line := err.Location() + if file != "" { + s = append(s, fmt.Sprintf("%s:%d", file, line)...) + s = append(s, ": "...) + } + } + if cerr, ok := err.(wrapper); ok { + s = append(s, cerr.Message()...) + err = cerr.Underlying() + } else { + s = append(s, err.Error()...) + err = nil + } + s = append(s, '}') + if err == nil { + break + } + s = append(s, ' ') + } + s = append(s, ']') + return string(s) +} + +// ErrorStack returns a string representation of the annotated error. If the +// error passed as the parameter is not an annotated error, the result is +// simply the result of the Error() method on that error. +// +// If the error is an annotated error, a multi-line string is returned where +// each line represents one entry in the annotation stack. The full filename +// from the call stack is used in the output. +// +// first error +// github.com/juju/errors/annotation_test.go:193: +// github.com/juju/errors/annotation_test.go:194: annotation +// github.com/juju/errors/annotation_test.go:195: +// github.com/juju/errors/annotation_test.go:196: more context +// github.com/juju/errors/annotation_test.go:197: +func ErrorStack(err error) string { + return strings.Join(errorStack(err), "\n") +} + +func errorStack(err error) []string { + if err == nil { + return nil + } + + // We want the first error first + var lines []string + for { + var buff []byte + if err, ok := err.(locationer); ok { + file, line := err.Location() + // Strip off the leading GOPATH/src path elements. + file = trimGoPath(file) + if file != "" { + buff = append(buff, fmt.Sprintf("%s:%d", file, line)...) + buff = append(buff, ": "...) + } + } + if cerr, ok := err.(wrapper); ok { + message := cerr.Message() + buff = append(buff, message...) + // If there is a cause for this error, and it is different to the cause + // of the underlying error, then output the error string in the stack trace. + var cause error + if err1, ok := err.(causer); ok { + cause = err1.Cause() + } + err = cerr.Underlying() + if cause != nil && !sameError(Cause(err), cause) { + if message != "" { + buff = append(buff, ": "...) + } + buff = append(buff, cause.Error()...) + } + } else { + buff = append(buff, err.Error()...) + err = nil + } + lines = append(lines, string(buff)) + if err == nil { + break + } + } + // reverse the lines to get the original error, which was at the end of + // the list, back to the start. + var result []string + for i := len(lines); i > 0; i-- { + result = append(result, lines[i-1]) + } + return result +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/path.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/path.go new file mode 100644 index 0000000..a7b726a --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/juju/errors/path.go @@ -0,0 +1,38 @@ +// Copyright 2013, 2014 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package errors + +import ( + "runtime" + "strings" +) + +// prefixSize is used internally to trim the user specific path from the +// front of the returned filenames from the runtime call stack. +var prefixSize int + +// goPath is the deduced path based on the location of this file as compiled. +var goPath string + +func init() { + _, file, _, ok := runtime.Caller(0) + if file == "?" { + return + } + if ok { + // We know that the end of the file should be: + // github.com/juju/errors/path.go + size := len(file) + suffix := len("github.com/juju/errors/path.go") + goPath = file[:size-suffix] + prefixSize = len(goPath) + } +} + +func trimGoPath(filename string) string { + if strings.HasPrefix(filename, goPath) { + return filename[prefixSize:] + } + return filename +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/LICENSE new file mode 100644 index 0000000..6600f1c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/LICENSE @@ -0,0 +1,165 @@ +GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_unix.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_unix.go new file mode 100644 index 0000000..37f407d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_unix.go @@ -0,0 +1,18 @@ +// +build freebsd openbsd netbsd dragonfly darwin linux + +package log + +import ( + "log" + "os" + "syscall" +) + +func CrashLog(file string) { + f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.Println(err.Error()) + } else { + syscall.Dup2(int(f.Fd()), 2) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_win.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_win.go new file mode 100644 index 0000000..7d612ee --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/crash_win.go @@ -0,0 +1,37 @@ +// +build windows + +package log + +import ( + "log" + "os" + "syscall" +) + +var ( + kernel32 = syscall.MustLoadDLL("kernel32.dll") + procSetStdHandle = kernel32.MustFindProc("SetStdHandle") +) + +func setStdHandle(stdhandle int32, handle syscall.Handle) error { + r0, _, e1 := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdhandle), uintptr(handle), 0) + if r0 == 0 { + if e1 != 0 { + return error(e1) + } + return syscall.EINVAL + } + return nil +} + +func CrashLog(file string) { + f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.Println(err.Error()) + } else { + err = setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(f.Fd())) + if err != nil { + log.Println(err.Error()) + } + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/log.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/log.go new file mode 100644 index 0000000..896b393 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/ngaut/log/log.go @@ -0,0 +1,380 @@ +//high level log wrapper, so it can output different log based on level +package log + +import ( + "fmt" + "io" + "log" + "os" + "runtime" + "sync" + "time" +) + +const ( + Ldate = log.Ldate + Llongfile = log.Llongfile + Lmicroseconds = log.Lmicroseconds + Lshortfile = log.Lshortfile + LstdFlags = log.LstdFlags + Ltime = log.Ltime +) + +type ( + LogLevel int + LogType int +) + +const ( + LOG_FATAL = LogType(0x1) + LOG_ERROR = LogType(0x2) + LOG_WARNING = LogType(0x4) + LOG_INFO = LogType(0x8) + LOG_DEBUG = LogType(0x10) +) + +const ( + LOG_LEVEL_NONE = LogLevel(0x0) + LOG_LEVEL_FATAL = LOG_LEVEL_NONE | LogLevel(LOG_FATAL) + LOG_LEVEL_ERROR = LOG_LEVEL_FATAL | LogLevel(LOG_ERROR) + LOG_LEVEL_WARN = LOG_LEVEL_ERROR | LogLevel(LOG_WARNING) + LOG_LEVEL_INFO = LOG_LEVEL_WARN | LogLevel(LOG_INFO) + LOG_LEVEL_DEBUG = LOG_LEVEL_INFO | LogLevel(LOG_DEBUG) + LOG_LEVEL_ALL = LOG_LEVEL_DEBUG +) + +const FORMAT_TIME_DAY string = "20060102" +const FORMAT_TIME_HOUR string = "2006010215" + +var _log *logger = New() + +func init() { + SetFlags(Ldate | Ltime | Lshortfile) + SetHighlighting(runtime.GOOS != "windows") +} + +func Logger() *log.Logger { + return _log._log +} + +func SetLevel(level LogLevel) { + _log.SetLevel(level) +} +func GetLogLevel() LogLevel { + return _log.level +} + +func SetOutput(out io.Writer) { + _log.SetOutput(out) +} + +func SetOutputByName(path string) error { + return _log.SetOutputByName(path) +} + +func SetFlags(flags int) { + _log._log.SetFlags(flags) +} + +func Info(v ...interface{}) { + _log.Info(v...) +} + +func Infof(format string, v ...interface{}) { + _log.Infof(format, v...) +} + +func Debug(v ...interface{}) { + _log.Debug(v...) +} + +func Debugf(format string, v ...interface{}) { + _log.Debugf(format, v...) +} + +func Warn(v ...interface{}) { + _log.Warning(v...) +} + +func Warnf(format string, v ...interface{}) { + _log.Warningf(format, v...) +} + +func Warning(v ...interface{}) { + _log.Warning(v...) +} + +func Warningf(format string, v ...interface{}) { + _log.Warningf(format, v...) +} + +func Error(v ...interface{}) { + _log.Error(v...) +} + +func Errorf(format string, v ...interface{}) { + _log.Errorf(format, v...) +} + +func Fatal(v ...interface{}) { + _log.Fatal(v...) +} + +func Fatalf(format string, v ...interface{}) { + _log.Fatalf(format, v...) +} + +func SetLevelByString(level string) { + _log.SetLevelByString(level) +} + +func SetHighlighting(highlighting bool) { + _log.SetHighlighting(highlighting) +} + +func SetRotateByDay() { + _log.SetRotateByDay() +} + +func SetRotateByHour() { + _log.SetRotateByHour() +} + +type logger struct { + _log *log.Logger + level LogLevel + highlighting bool + + dailyRolling bool + hourRolling bool + + fileName string + logSuffix string + fd *os.File + + lock sync.Mutex +} + +func (l *logger) SetHighlighting(highlighting bool) { + l.highlighting = highlighting +} + +func (l *logger) SetLevel(level LogLevel) { + l.level = level +} + +func (l *logger) SetLevelByString(level string) { + l.level = StringToLogLevel(level) +} + +func (l *logger) SetRotateByDay() { + l.dailyRolling = true + l.logSuffix = genDayTime(time.Now()) +} + +func (l *logger) SetRotateByHour() { + l.hourRolling = true + l.logSuffix = genHourTime(time.Now()) +} + +func (l *logger) rotate() error { + l.lock.Lock() + defer l.lock.Unlock() + + var suffix string + if l.dailyRolling { + suffix = genDayTime(time.Now()) + } else if l.hourRolling { + suffix = genHourTime(time.Now()) + } else { + return nil + } + + // Notice: if suffix is not equal to l.LogSuffix, then rotate + if suffix != l.logSuffix { + err := l.doRotate(suffix) + if err != nil { + return err + } + } + + return nil +} + +func (l *logger) doRotate(suffix string) error { + // Notice: Not check error, is this ok? + l.fd.Close() + + lastFileName := l.fileName + "." + l.logSuffix + err := os.Rename(l.fileName, lastFileName) + if err != nil { + return err + } + + err = l.SetOutputByName(l.fileName) + if err != nil { + return err + } + + l.logSuffix = suffix + + return nil +} + +func (l *logger) SetOutput(out io.Writer) { + l._log = log.New(out, l._log.Prefix(), l._log.Flags()) +} + +func (l *logger) SetOutputByName(path string) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666) + if err != nil { + log.Fatal(err) + } + + l.SetOutput(f) + + l.fileName = path + l.fd = f + + return err +} + +func (l *logger) log(t LogType, v ...interface{}) { + if l.level|LogLevel(t) != l.level { + return + } + + err := l.rotate() + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return + } + + v1 := make([]interface{}, len(v)+2) + logStr, logColor := LogTypeToString(t) + if l.highlighting { + v1[0] = "\033" + logColor + "m[" + logStr + "]" + copy(v1[1:], v) + v1[len(v)+1] = "\033[0m" + } else { + v1[0] = "[" + logStr + "]" + copy(v1[1:], v) + v1[len(v)+1] = "" + } + + s := fmt.Sprintln(v1...) + l._log.Output(4, s) +} + +func (l *logger) logf(t LogType, format string, v ...interface{}) { + if l.level|LogLevel(t) != l.level { + return + } + + err := l.rotate() + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return + } + + logStr, logColor := LogTypeToString(t) + var s string + if l.highlighting { + s = "\033" + logColor + "m[" + logStr + "] " + fmt.Sprintf(format, v...) + "\033[0m" + } else { + s = "[" + logStr + "] " + fmt.Sprintf(format, v...) + } + l._log.Output(4, s) +} + +func (l *logger) Fatal(v ...interface{}) { + l.log(LOG_FATAL, v...) + os.Exit(-1) +} + +func (l *logger) Fatalf(format string, v ...interface{}) { + l.logf(LOG_FATAL, format, v...) + os.Exit(-1) +} + +func (l *logger) Error(v ...interface{}) { + l.log(LOG_ERROR, v...) +} + +func (l *logger) Errorf(format string, v ...interface{}) { + l.logf(LOG_ERROR, format, v...) +} + +func (l *logger) Warning(v ...interface{}) { + l.log(LOG_WARNING, v...) +} + +func (l *logger) Warningf(format string, v ...interface{}) { + l.logf(LOG_WARNING, format, v...) +} + +func (l *logger) Debug(v ...interface{}) { + l.log(LOG_DEBUG, v...) +} + +func (l *logger) Debugf(format string, v ...interface{}) { + l.logf(LOG_DEBUG, format, v...) +} + +func (l *logger) Info(v ...interface{}) { + l.log(LOG_INFO, v...) +} + +func (l *logger) Infof(format string, v ...interface{}) { + l.logf(LOG_INFO, format, v...) +} + +func StringToLogLevel(level string) LogLevel { + switch level { + case "fatal": + return LOG_LEVEL_FATAL + case "error": + return LOG_LEVEL_ERROR + case "warn": + return LOG_LEVEL_WARN + case "warning": + return LOG_LEVEL_WARN + case "debug": + return LOG_LEVEL_DEBUG + case "info": + return LOG_LEVEL_INFO + } + return LOG_LEVEL_ALL +} + +func LogTypeToString(t LogType) (string, string) { + switch t { + case LOG_FATAL: + return "fatal", "[0;31" + case LOG_ERROR: + return "error", "[0;31" + case LOG_WARNING: + return "warning", "[0;33" + case LOG_DEBUG: + return "debug", "[0;36" + case LOG_INFO: + return "info", "[0;37" + } + return "unknown", "[0;37" +} + +func genDayTime(t time.Time) string { + return t.Format(FORMAT_TIME_DAY) +} + +func genHourTime(t time.Time) string { + return t.Format(FORMAT_TIME_HOUR) +} + +func New() *logger { + return Newlogger(os.Stderr, "") +} + +func Newlogger(w io.Writer, prefix string) *logger { + return &logger{_log: log.New(w, prefix, LstdFlags), level: LOG_LEVEL_ALL, highlighting: true} +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/benchmark.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/benchmark.go new file mode 100644 index 0000000..46ea9dc --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/benchmark.go @@ -0,0 +1,187 @@ +// Copyright (c) 2012 The Go Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package check + +import ( + "fmt" + "runtime" + "time" +) + +var memStats runtime.MemStats + +// testingB is a type passed to Benchmark functions to manage benchmark +// timing and to specify the number of iterations to run. +type timer struct { + start time.Time // Time test or benchmark started + duration time.Duration + N int + bytes int64 + timerOn bool + benchTime time.Duration + // The initial states of memStats.Mallocs and memStats.TotalAlloc. + startAllocs uint64 + startBytes uint64 + // The net total of this test after being run. + netAllocs uint64 + netBytes uint64 +} + +// StartTimer starts timing a test. This function is called automatically +// before a benchmark starts, but it can also used to resume timing after +// a call to StopTimer. +func (c *C) StartTimer() { + if !c.timerOn { + c.start = time.Now() + c.timerOn = true + + runtime.ReadMemStats(&memStats) + c.startAllocs = memStats.Mallocs + c.startBytes = memStats.TotalAlloc + } +} + +// StopTimer stops timing a test. This can be used to pause the timer +// while performing complex initialization that you don't +// want to measure. +func (c *C) StopTimer() { + if c.timerOn { + c.duration += time.Now().Sub(c.start) + c.timerOn = false + runtime.ReadMemStats(&memStats) + c.netAllocs += memStats.Mallocs - c.startAllocs + c.netBytes += memStats.TotalAlloc - c.startBytes + } +} + +// ResetTimer sets the elapsed benchmark time to zero. +// It does not affect whether the timer is running. +func (c *C) ResetTimer() { + if c.timerOn { + c.start = time.Now() + runtime.ReadMemStats(&memStats) + c.startAllocs = memStats.Mallocs + c.startBytes = memStats.TotalAlloc + } + c.duration = 0 + c.netAllocs = 0 + c.netBytes = 0 +} + +// SetBytes informs the number of bytes that the benchmark processes +// on each iteration. If this is called in a benchmark it will also +// report MB/s. +func (c *C) SetBytes(n int64) { + c.bytes = n +} + +func (c *C) nsPerOp() int64 { + if c.N <= 0 { + return 0 + } + return c.duration.Nanoseconds() / int64(c.N) +} + +func (c *C) mbPerSec() float64 { + if c.bytes <= 0 || c.duration <= 0 || c.N <= 0 { + return 0 + } + return (float64(c.bytes) * float64(c.N) / 1e6) / c.duration.Seconds() +} + +func (c *C) timerString() string { + if c.N <= 0 { + return fmt.Sprintf("%3.3fs", float64(c.duration.Nanoseconds())/1e9) + } + mbs := c.mbPerSec() + mb := "" + if mbs != 0 { + mb = fmt.Sprintf("\t%7.2f MB/s", mbs) + } + nsop := c.nsPerOp() + ns := fmt.Sprintf("%10d ns/op", nsop) + if c.N > 0 && nsop < 100 { + // The format specifiers here make sure that + // the ones digits line up for all three possible formats. + if nsop < 10 { + ns = fmt.Sprintf("%13.2f ns/op", float64(c.duration.Nanoseconds())/float64(c.N)) + } else { + ns = fmt.Sprintf("%12.1f ns/op", float64(c.duration.Nanoseconds())/float64(c.N)) + } + } + memStats := "" + if c.benchMem { + allocedBytes := fmt.Sprintf("%8d B/op", int64(c.netBytes)/int64(c.N)) + allocs := fmt.Sprintf("%8d allocs/op", int64(c.netAllocs)/int64(c.N)) + memStats = fmt.Sprintf("\t%s\t%s", allocedBytes, allocs) + } + return fmt.Sprintf("%8d\t%s%s%s", c.N, ns, mb, memStats) +} + +func min(x, y int) int { + if x > y { + return y + } + return x +} + +func max(x, y int) int { + if x < y { + return y + } + return x +} + +// roundDown10 rounds a number down to the nearest power of 10. +func roundDown10(n int) int { + var tens = 0 + // tens = floor(log_10(n)) + for n > 10 { + n = n / 10 + tens++ + } + // result = 10^tens + result := 1 + for i := 0; i < tens; i++ { + result *= 10 + } + return result +} + +// roundUp rounds x up to a number of the form [1eX, 2eX, 5eX]. +func roundUp(n int) int { + base := roundDown10(n) + if n < (2 * base) { + return 2 * base + } + if n < (5 * base) { + return 5 * base + } + return 10 * base +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/check.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/check.go new file mode 100644 index 0000000..c99392a --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/check.go @@ -0,0 +1,954 @@ +// Package check is a rich testing extension for Go's testing package. +// +// For details about the project, see: +// +// http://labix.org/gocheck +// +package check + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "os" + "path" + "path/filepath" + "reflect" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ----------------------------------------------------------------------- +// Internal type which deals with suite method calling. + +const ( + fixtureKd = iota + testKd +) + +type funcKind int + +const ( + succeededSt = iota + failedSt + skippedSt + panickedSt + fixturePanickedSt + missedSt +) + +type funcStatus uint32 + +// A method value can't reach its own Method structure. +type methodType struct { + reflect.Value + Info reflect.Method +} + +func newMethod(receiver reflect.Value, i int) *methodType { + return &methodType{receiver.Method(i), receiver.Type().Method(i)} +} + +func (method *methodType) PC() uintptr { + return method.Info.Func.Pointer() +} + +func (method *methodType) suiteName() string { + t := method.Info.Type.In(0) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.Name() +} + +func (method *methodType) String() string { + return method.suiteName() + "." + method.Info.Name +} + +func (method *methodType) matches(re *regexp.Regexp) bool { + return (re.MatchString(method.Info.Name) || + re.MatchString(method.suiteName()) || + re.MatchString(method.String())) +} + +type C struct { + method *methodType + kind funcKind + testName string + _status funcStatus + logb *logger + logw io.Writer + done chan *C + reason string + mustFail bool + tempDir *tempDir + benchMem bool + startTime time.Time + timer +} + +func (c *C) status() funcStatus { + return funcStatus(atomic.LoadUint32((*uint32)(&c._status))) +} + +func (c *C) setStatus(s funcStatus) { + atomic.StoreUint32((*uint32)(&c._status), uint32(s)) +} + +func (c *C) stopNow() { + runtime.Goexit() +} + +// logger is a concurrency safe byte.Buffer +type logger struct { + sync.Mutex + writer bytes.Buffer +} + +func (l *logger) Write(buf []byte) (int, error) { + l.Lock() + defer l.Unlock() + return l.writer.Write(buf) +} + +func (l *logger) WriteTo(w io.Writer) (int64, error) { + l.Lock() + defer l.Unlock() + return l.writer.WriteTo(w) +} + +func (l *logger) String() string { + l.Lock() + defer l.Unlock() + return l.writer.String() +} + +// ----------------------------------------------------------------------- +// Handling of temporary files and directories. + +type tempDir struct { + sync.Mutex + path string + counter int +} + +func (td *tempDir) newPath() string { + td.Lock() + defer td.Unlock() + if td.path == "" { + var err error + for i := 0; i != 100; i++ { + path := fmt.Sprintf("%s%ccheck-%d", os.TempDir(), os.PathSeparator, rand.Int()) + if err = os.Mkdir(path, 0700); err == nil { + td.path = path + break + } + } + if td.path == "" { + panic("Couldn't create temporary directory: " + err.Error()) + } + } + result := filepath.Join(td.path, strconv.Itoa(td.counter)) + td.counter += 1 + return result +} + +func (td *tempDir) removeAll() { + td.Lock() + defer td.Unlock() + if td.path != "" { + err := os.RemoveAll(td.path) + if err != nil { + fmt.Fprintf(os.Stderr, "WARNING: Error cleaning up temporaries: "+err.Error()) + } + } +} + +// Create a new temporary directory which is automatically removed after +// the suite finishes running. +func (c *C) MkDir() string { + path := c.tempDir.newPath() + if err := os.Mkdir(path, 0700); err != nil { + panic(fmt.Sprintf("Couldn't create temporary directory %s: %s", path, err.Error())) + } + return path +} + +// ----------------------------------------------------------------------- +// Low-level logging functions. + +func (c *C) log(args ...interface{}) { + c.writeLog([]byte(fmt.Sprint(args...) + "\n")) +} + +func (c *C) logf(format string, args ...interface{}) { + c.writeLog([]byte(fmt.Sprintf(format+"\n", args...))) +} + +func (c *C) logNewLine() { + c.writeLog([]byte{'\n'}) +} + +func (c *C) writeLog(buf []byte) { + c.logb.Write(buf) + if c.logw != nil { + c.logw.Write(buf) + } +} + +func hasStringOrError(x interface{}) (ok bool) { + _, ok = x.(fmt.Stringer) + if ok { + return + } + _, ok = x.(error) + return +} + +func (c *C) logValue(label string, value interface{}) { + if label == "" { + if hasStringOrError(value) { + c.logf("... %#v (%q)", value, value) + } else { + c.logf("... %#v", value) + } + } else if value == nil { + c.logf("... %s = nil", label) + } else { + if hasStringOrError(value) { + fv := fmt.Sprintf("%#v", value) + qv := fmt.Sprintf("%q", value) + if fv != qv { + c.logf("... %s %s = %s (%s)", label, reflect.TypeOf(value), fv, qv) + return + } + } + if s, ok := value.(string); ok && isMultiLine(s) { + c.logf(`... %s %s = "" +`, label, reflect.TypeOf(value)) + c.logMultiLine(s) + } else { + c.logf("... %s %s = %#v", label, reflect.TypeOf(value), value) + } + } +} + +func (c *C) logMultiLine(s string) { + b := make([]byte, 0, len(s)*2) + i := 0 + n := len(s) + for i < n { + j := i + 1 + for j < n && s[j-1] != '\n' { + j++ + } + b = append(b, "... "...) + b = strconv.AppendQuote(b, s[i:j]) + if j < n { + b = append(b, " +"...) + } + b = append(b, '\n') + i = j + } + c.writeLog(b) +} + +func isMultiLine(s string) bool { + for i := 0; i+1 < len(s); i++ { + if s[i] == '\n' { + return true + } + } + return false +} + +func (c *C) logString(issue string) { + c.log("... ", issue) +} + +func (c *C) logCaller(skip int) { + // This is a bit heavier than it ought to be. + skip += 1 // Our own frame. + pc, callerFile, callerLine, ok := runtime.Caller(skip) + if !ok { + return + } + var testFile string + var testLine int + testFunc := runtime.FuncForPC(c.method.PC()) + if runtime.FuncForPC(pc) != testFunc { + for { + skip += 1 + if pc, file, line, ok := runtime.Caller(skip); ok { + // Note that the test line may be different on + // distinct calls for the same test. Showing + // the "internal" line is helpful when debugging. + if runtime.FuncForPC(pc) == testFunc { + testFile, testLine = file, line + break + } + } else { + break + } + } + } + if testFile != "" && (testFile != callerFile || testLine != callerLine) { + c.logCode(testFile, testLine) + } + c.logCode(callerFile, callerLine) +} + +func (c *C) logCode(path string, line int) { + c.logf("%s:%d:", nicePath(path), line) + code, err := printLine(path, line) + if code == "" { + code = "..." // XXX Open the file and take the raw line. + if err != nil { + code += err.Error() + } + } + c.log(indent(code, " ")) +} + +var valueGo = filepath.Join("reflect", "value.go") +var asmGo = filepath.Join("runtime", "asm_") + +func (c *C) logPanic(skip int, value interface{}) { + skip++ // Our own frame. + initialSkip := skip + for ; ; skip++ { + if pc, file, line, ok := runtime.Caller(skip); ok { + if skip == initialSkip { + c.logf("... Panic: %s (PC=0x%X)\n", value, pc) + } + name := niceFuncName(pc) + path := nicePath(file) + if strings.Contains(path, "/gopkg.in/check.v") { + continue + } + if name == "Value.call" && strings.HasSuffix(path, valueGo) { + continue + } + if (name == "call16" || name == "call32") && strings.Contains(path, asmGo) { + continue + } + c.logf("%s:%d\n in %s", nicePath(file), line, name) + } else { + break + } + } +} + +func (c *C) logSoftPanic(issue string) { + c.log("... Panic: ", issue) +} + +func (c *C) logArgPanic(method *methodType, expectedType string) { + c.logf("... Panic: %s argument should be %s", + niceFuncName(method.PC()), expectedType) +} + +// ----------------------------------------------------------------------- +// Some simple formatting helpers. + +var initWD, initWDErr = os.Getwd() + +func init() { + if initWDErr == nil { + initWD = strings.Replace(initWD, "\\", "/", -1) + "/" + } +} + +func nicePath(path string) string { + if initWDErr == nil { + if strings.HasPrefix(path, initWD) { + return path[len(initWD):] + } + } + return path +} + +func niceFuncPath(pc uintptr) string { + function := runtime.FuncForPC(pc) + if function != nil { + filename, line := function.FileLine(pc) + return fmt.Sprintf("%s:%d", nicePath(filename), line) + } + return "" +} + +func niceFuncName(pc uintptr) string { + function := runtime.FuncForPC(pc) + if function != nil { + name := path.Base(function.Name()) + if i := strings.Index(name, "."); i > 0 { + name = name[i+1:] + } + if strings.HasPrefix(name, "(*") { + if i := strings.Index(name, ")"); i > 0 { + name = name[2:i] + name[i+1:] + } + } + if i := strings.LastIndex(name, ".*"); i != -1 { + name = name[:i] + "." + name[i+2:] + } + if i := strings.LastIndex(name, "·"); i != -1 { + name = name[:i] + "." + name[i+2:] + } + return name + } + return "" +} + +// ----------------------------------------------------------------------- +// Result tracker to aggregate call results. + +type Result struct { + Succeeded int + Failed int + Skipped int + Panicked int + FixturePanicked int + ExpectedFailures int + Missed int // Not even tried to run, related to a panic in the fixture. + RunError error // Houston, we've got a problem. + WorkDir string // If KeepWorkDir is true +} + +type resultTracker struct { + result Result + _lastWasProblem bool + _waiting int + _missed int + _expectChan chan *C + _doneChan chan *C + _stopChan chan bool +} + +func newResultTracker() *resultTracker { + return &resultTracker{_expectChan: make(chan *C), // Synchronous + _doneChan: make(chan *C, 32), // Asynchronous + _stopChan: make(chan bool)} // Synchronous +} + +func (tracker *resultTracker) start() { + go tracker._loopRoutine() +} + +func (tracker *resultTracker) waitAndStop() { + <-tracker._stopChan +} + +func (tracker *resultTracker) expectCall(c *C) { + tracker._expectChan <- c +} + +func (tracker *resultTracker) callDone(c *C) { + tracker._doneChan <- c +} + +func (tracker *resultTracker) _loopRoutine() { + for { + var c *C + if tracker._waiting > 0 { + // Calls still running. Can't stop. + select { + // XXX Reindent this (not now to make diff clear) + case c = <-tracker._expectChan: + tracker._waiting += 1 + case c = <-tracker._doneChan: + tracker._waiting -= 1 + switch c.status() { + case succeededSt: + if c.kind == testKd { + if c.mustFail { + tracker.result.ExpectedFailures++ + } else { + tracker.result.Succeeded++ + } + } + case failedSt: + tracker.result.Failed++ + case panickedSt: + if c.kind == fixtureKd { + tracker.result.FixturePanicked++ + } else { + tracker.result.Panicked++ + } + case fixturePanickedSt: + // Track it as missed, since the panic + // was on the fixture, not on the test. + tracker.result.Missed++ + case missedSt: + tracker.result.Missed++ + case skippedSt: + if c.kind == testKd { + tracker.result.Skipped++ + } + } + } + } else { + // No calls. Can stop, but no done calls here. + select { + case tracker._stopChan <- true: + return + case c = <-tracker._expectChan: + tracker._waiting += 1 + case c = <-tracker._doneChan: + panic("Tracker got an unexpected done call.") + } + } + } +} + +// ----------------------------------------------------------------------- +// The underlying suite runner. + +type suiteRunner struct { + suite interface{} + setUpSuite, tearDownSuite *methodType + setUpTest, tearDownTest *methodType + tests []*methodType + tracker *resultTracker + tempDir *tempDir + keepDir bool + output *outputWriter + reportedProblemLast bool + benchTime time.Duration + benchMem bool +} + +type RunConf struct { + Output io.Writer + Stream bool + Verbose bool + Filter string + Benchmark bool + BenchmarkTime time.Duration // Defaults to 1 second + BenchmarkMem bool + KeepWorkDir bool +} + +// Create a new suiteRunner able to run all methods in the given suite. +func newSuiteRunner(suite interface{}, runConf *RunConf) *suiteRunner { + var conf RunConf + if runConf != nil { + conf = *runConf + } + if conf.Output == nil { + conf.Output = os.Stdout + } + if conf.Benchmark { + conf.Verbose = true + } + + suiteType := reflect.TypeOf(suite) + suiteNumMethods := suiteType.NumMethod() + suiteValue := reflect.ValueOf(suite) + + runner := &suiteRunner{ + suite: suite, + output: newOutputWriter(conf.Output, conf.Stream, conf.Verbose), + tracker: newResultTracker(), + benchTime: conf.BenchmarkTime, + benchMem: conf.BenchmarkMem, + tempDir: &tempDir{}, + keepDir: conf.KeepWorkDir, + tests: make([]*methodType, 0, suiteNumMethods), + } + if runner.benchTime == 0 { + runner.benchTime = 1 * time.Second + } + + var filterRegexp *regexp.Regexp + if conf.Filter != "" { + if regexp, err := regexp.Compile(conf.Filter); err != nil { + msg := "Bad filter expression: " + err.Error() + runner.tracker.result.RunError = errors.New(msg) + return runner + } else { + filterRegexp = regexp + } + } + + for i := 0; i != suiteNumMethods; i++ { + method := newMethod(suiteValue, i) + switch method.Info.Name { + case "SetUpSuite": + runner.setUpSuite = method + case "TearDownSuite": + runner.tearDownSuite = method + case "SetUpTest": + runner.setUpTest = method + case "TearDownTest": + runner.tearDownTest = method + default: + prefix := "Test" + if conf.Benchmark { + prefix = "Benchmark" + } + if !strings.HasPrefix(method.Info.Name, prefix) { + continue + } + if filterRegexp == nil || method.matches(filterRegexp) { + runner.tests = append(runner.tests, method) + } + } + } + return runner +} + +// Run all methods in the given suite. +func (runner *suiteRunner) run() *Result { + if runner.tracker.result.RunError == nil && len(runner.tests) > 0 { + runner.tracker.start() + if runner.checkFixtureArgs() { + c := runner.runFixture(runner.setUpSuite, "", nil) + if c == nil || c.status() == succeededSt { + for i := 0; i != len(runner.tests); i++ { + c := runner.runTest(runner.tests[i]) + if c.status() == fixturePanickedSt { + runner.skipTests(missedSt, runner.tests[i+1:]) + break + } + } + } else if c != nil && c.status() == skippedSt { + runner.skipTests(skippedSt, runner.tests) + } else { + runner.skipTests(missedSt, runner.tests) + } + runner.runFixture(runner.tearDownSuite, "", nil) + } else { + runner.skipTests(missedSt, runner.tests) + } + runner.tracker.waitAndStop() + if runner.keepDir { + runner.tracker.result.WorkDir = runner.tempDir.path + } else { + runner.tempDir.removeAll() + } + } + return &runner.tracker.result +} + +// Create a call object with the given suite method, and fork a +// goroutine with the provided dispatcher for running it. +func (runner *suiteRunner) forkCall(method *methodType, kind funcKind, testName string, logb *logger, dispatcher func(c *C)) *C { + var logw io.Writer + if runner.output.Stream { + logw = runner.output + } + if logb == nil { + logb = new(logger) + } + c := &C{ + method: method, + kind: kind, + testName: testName, + logb: logb, + logw: logw, + tempDir: runner.tempDir, + done: make(chan *C, 1), + timer: timer{benchTime: runner.benchTime}, + startTime: time.Now(), + benchMem: runner.benchMem, + } + runner.tracker.expectCall(c) + go (func() { + runner.reportCallStarted(c) + defer runner.callDone(c) + dispatcher(c) + })() + return c +} + +// Same as forkCall(), but wait for call to finish before returning. +func (runner *suiteRunner) runFunc(method *methodType, kind funcKind, testName string, logb *logger, dispatcher func(c *C)) *C { + c := runner.forkCall(method, kind, testName, logb, dispatcher) + <-c.done + return c +} + +// Handle a finished call. If there were any panics, update the call status +// accordingly. Then, mark the call as done and report to the tracker. +func (runner *suiteRunner) callDone(c *C) { + value := recover() + if value != nil { + switch v := value.(type) { + case *fixturePanic: + if v.status == skippedSt { + c.setStatus(skippedSt) + } else { + c.logSoftPanic("Fixture has panicked (see related PANIC)") + c.setStatus(fixturePanickedSt) + } + default: + c.logPanic(1, value) + c.setStatus(panickedSt) + } + } + if c.mustFail { + switch c.status() { + case failedSt: + c.setStatus(succeededSt) + case succeededSt: + c.setStatus(failedSt) + c.logString("Error: Test succeeded, but was expected to fail") + c.logString("Reason: " + c.reason) + } + } + + runner.reportCallDone(c) + c.done <- c +} + +// Runs a fixture call synchronously. The fixture will still be run in a +// goroutine like all suite methods, but this method will not return +// while the fixture goroutine is not done, because the fixture must be +// run in a desired order. +func (runner *suiteRunner) runFixture(method *methodType, testName string, logb *logger) *C { + if method != nil { + c := runner.runFunc(method, fixtureKd, testName, logb, func(c *C) { + c.ResetTimer() + c.StartTimer() + defer c.StopTimer() + c.method.Call([]reflect.Value{reflect.ValueOf(c)}) + }) + return c + } + return nil +} + +// Run the fixture method with runFixture(), but panic with a fixturePanic{} +// in case the fixture method panics. This makes it easier to track the +// fixture panic together with other call panics within forkTest(). +func (runner *suiteRunner) runFixtureWithPanic(method *methodType, testName string, logb *logger, skipped *bool) *C { + if skipped != nil && *skipped { + return nil + } + c := runner.runFixture(method, testName, logb) + if c != nil && c.status() != succeededSt { + if skipped != nil { + *skipped = c.status() == skippedSt + } + panic(&fixturePanic{c.status(), method}) + } + return c +} + +type fixturePanic struct { + status funcStatus + method *methodType +} + +// Run the suite test method, together with the test-specific fixture, +// asynchronously. +func (runner *suiteRunner) forkTest(method *methodType) *C { + testName := method.String() + return runner.forkCall(method, testKd, testName, nil, func(c *C) { + var skipped bool + defer runner.runFixtureWithPanic(runner.tearDownTest, testName, nil, &skipped) + defer c.StopTimer() + benchN := 1 + for { + runner.runFixtureWithPanic(runner.setUpTest, testName, c.logb, &skipped) + mt := c.method.Type() + if mt.NumIn() != 1 || mt.In(0) != reflect.TypeOf(c) { + // Rather than a plain panic, provide a more helpful message when + // the argument type is incorrect. + c.setStatus(panickedSt) + c.logArgPanic(c.method, "*check.C") + return + } + if strings.HasPrefix(c.method.Info.Name, "Test") { + c.ResetTimer() + c.StartTimer() + c.method.Call([]reflect.Value{reflect.ValueOf(c)}) + return + } + if !strings.HasPrefix(c.method.Info.Name, "Benchmark") { + panic("unexpected method prefix: " + c.method.Info.Name) + } + + runtime.GC() + c.N = benchN + c.ResetTimer() + c.StartTimer() + c.method.Call([]reflect.Value{reflect.ValueOf(c)}) + c.StopTimer() + if c.status() != succeededSt || c.duration >= c.benchTime || benchN >= 1e9 { + return + } + perOpN := int(1e9) + if c.nsPerOp() != 0 { + perOpN = int(c.benchTime.Nanoseconds() / c.nsPerOp()) + } + + // Logic taken from the stock testing package: + // - Run more iterations than we think we'll need for a second (1.5x). + // - Don't grow too fast in case we had timing errors previously. + // - Be sure to run at least one more than last time. + benchN = max(min(perOpN+perOpN/2, 100*benchN), benchN+1) + benchN = roundUp(benchN) + + skipped = true // Don't run the deferred one if this panics. + runner.runFixtureWithPanic(runner.tearDownTest, testName, nil, nil) + skipped = false + } + }) +} + +// Same as forkTest(), but wait for the test to finish before returning. +func (runner *suiteRunner) runTest(method *methodType) *C { + c := runner.forkTest(method) + <-c.done + return c +} + +// Helper to mark tests as skipped or missed. A bit heavy for what +// it does, but it enables homogeneous handling of tracking, including +// nice verbose output. +func (runner *suiteRunner) skipTests(status funcStatus, methods []*methodType) { + for _, method := range methods { + runner.runFunc(method, testKd, "", nil, func(c *C) { + c.setStatus(status) + }) + } +} + +// Verify if the fixture arguments are *check.C. In case of errors, +// log the error as a panic in the fixture method call, and return false. +func (runner *suiteRunner) checkFixtureArgs() bool { + succeeded := true + argType := reflect.TypeOf(&C{}) + for _, method := range []*methodType{runner.setUpSuite, runner.tearDownSuite, runner.setUpTest, runner.tearDownTest} { + if method != nil { + mt := method.Type() + if mt.NumIn() != 1 || mt.In(0) != argType { + succeeded = false + runner.runFunc(method, fixtureKd, "", nil, func(c *C) { + c.logArgPanic(method, "*check.C") + c.setStatus(panickedSt) + }) + } + } + } + return succeeded +} + +func (runner *suiteRunner) reportCallStarted(c *C) { + runner.output.WriteCallStarted("START", c) +} + +func (runner *suiteRunner) reportCallDone(c *C) { + runner.tracker.callDone(c) + switch c.status() { + case succeededSt: + if c.mustFail { + runner.output.WriteCallSuccess("FAIL EXPECTED", c) + } else { + runner.output.WriteCallSuccess("PASS", c) + } + case skippedSt: + runner.output.WriteCallSuccess("SKIP", c) + case failedSt: + runner.output.WriteCallProblem("FAIL", c) + case panickedSt: + runner.output.WriteCallProblem("PANIC", c) + case fixturePanickedSt: + // That's a testKd call reporting that its fixture + // has panicked. The fixture call which caused the + // panic itself was tracked above. We'll report to + // aid debugging. + runner.output.WriteCallProblem("PANIC", c) + case missedSt: + runner.output.WriteCallSuccess("MISS", c) + } +} + +// ----------------------------------------------------------------------- +// Output writer manages atomic output writing according to settings. + +type outputWriter struct { + m sync.Mutex + writer io.Writer + wroteCallProblemLast bool + Stream bool + Verbose bool +} + +func newOutputWriter(writer io.Writer, stream, verbose bool) *outputWriter { + return &outputWriter{writer: writer, Stream: stream, Verbose: verbose} +} + +func (ow *outputWriter) Write(content []byte) (n int, err error) { + ow.m.Lock() + n, err = ow.writer.Write(content) + ow.m.Unlock() + return +} + +func (ow *outputWriter) WriteCallStarted(label string, c *C) { + if ow.Stream { + header := renderCallHeader(label, c, "", "\n") + ow.m.Lock() + ow.writer.Write([]byte(header)) + ow.m.Unlock() + } +} + +func (ow *outputWriter) WriteCallProblem(label string, c *C) { + var prefix string + if !ow.Stream { + prefix = "\n-----------------------------------" + + "-----------------------------------\n" + } + header := renderCallHeader(label, c, prefix, "\n\n") + ow.m.Lock() + ow.wroteCallProblemLast = true + ow.writer.Write([]byte(header)) + if !ow.Stream { + c.logb.WriteTo(ow.writer) + } + ow.m.Unlock() +} + +func (ow *outputWriter) WriteCallSuccess(label string, c *C) { + if ow.Stream || (ow.Verbose && c.kind == testKd) { + // TODO Use a buffer here. + var suffix string + if c.reason != "" { + suffix = " (" + c.reason + ")" + } + if c.status() == succeededSt { + suffix += "\t" + c.timerString() + } + suffix += "\n" + if ow.Stream { + suffix += "\n" + } + header := renderCallHeader(label, c, "", suffix) + ow.m.Lock() + // Resist temptation of using line as prefix above due to race. + if !ow.Stream && ow.wroteCallProblemLast { + header = "\n-----------------------------------" + + "-----------------------------------\n" + + header + } + ow.wroteCallProblemLast = false + ow.writer.Write([]byte(header)) + ow.m.Unlock() + } +} + +func renderCallHeader(label string, c *C, prefix, suffix string) string { + pc := c.method.PC() + return fmt.Sprintf("%s%s: %s: %s%s", prefix, label, niceFuncPath(pc), + niceFuncName(pc), suffix) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers.go new file mode 100644 index 0000000..bac3387 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers.go @@ -0,0 +1,458 @@ +package check + +import ( + "fmt" + "reflect" + "regexp" +) + +// ----------------------------------------------------------------------- +// CommentInterface and Commentf helper, to attach extra information to checks. + +type comment struct { + format string + args []interface{} +} + +// Commentf returns an infomational value to use with Assert or Check calls. +// If the checker test fails, the provided arguments will be passed to +// fmt.Sprintf, and will be presented next to the logged failure. +// +// For example: +// +// c.Assert(v, Equals, 42, Commentf("Iteration #%d failed.", i)) +// +// Note that if the comment is constant, a better option is to +// simply use a normal comment right above or next to the line, as +// it will also get printed with any errors: +// +// c.Assert(l, Equals, 8192) // Ensure buffer size is correct (bug #123) +// +func Commentf(format string, args ...interface{}) CommentInterface { + return &comment{format, args} +} + +// CommentInterface must be implemented by types that attach extra +// information to failed checks. See the Commentf function for details. +type CommentInterface interface { + CheckCommentString() string +} + +func (c *comment) CheckCommentString() string { + return fmt.Sprintf(c.format, c.args...) +} + +// ----------------------------------------------------------------------- +// The Checker interface. + +// The Checker interface must be provided by checkers used with +// the Assert and Check verification methods. +type Checker interface { + Info() *CheckerInfo + Check(params []interface{}, names []string) (result bool, error string) +} + +// See the Checker interface. +type CheckerInfo struct { + Name string + Params []string +} + +func (info *CheckerInfo) Info() *CheckerInfo { + return info +} + +// ----------------------------------------------------------------------- +// Not checker logic inverter. + +// The Not checker inverts the logic of the provided checker. The +// resulting checker will succeed where the original one failed, and +// vice-versa. +// +// For example: +// +// c.Assert(a, Not(Equals), b) +// +func Not(checker Checker) Checker { + return ¬Checker{checker} +} + +type notChecker struct { + sub Checker +} + +func (checker *notChecker) Info() *CheckerInfo { + info := *checker.sub.Info() + info.Name = "Not(" + info.Name + ")" + return &info +} + +func (checker *notChecker) Check(params []interface{}, names []string) (result bool, error string) { + result, error = checker.sub.Check(params, names) + result = !result + return +} + +// ----------------------------------------------------------------------- +// IsNil checker. + +type isNilChecker struct { + *CheckerInfo +} + +// The IsNil checker tests whether the obtained value is nil. +// +// For example: +// +// c.Assert(err, IsNil) +// +var IsNil Checker = &isNilChecker{ + &CheckerInfo{Name: "IsNil", Params: []string{"value"}}, +} + +func (checker *isNilChecker) Check(params []interface{}, names []string) (result bool, error string) { + return isNil(params[0]), "" +} + +func isNil(obtained interface{}) (result bool) { + if obtained == nil { + result = true + } else { + switch v := reflect.ValueOf(obtained); v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return v.IsNil() + } + } + return +} + +// ----------------------------------------------------------------------- +// NotNil checker. Alias for Not(IsNil), since it's so common. + +type notNilChecker struct { + *CheckerInfo +} + +// The NotNil checker verifies that the obtained value is not nil. +// +// For example: +// +// c.Assert(iface, NotNil) +// +// This is an alias for Not(IsNil), made available since it's a +// fairly common check. +// +var NotNil Checker = ¬NilChecker{ + &CheckerInfo{Name: "NotNil", Params: []string{"value"}}, +} + +func (checker *notNilChecker) Check(params []interface{}, names []string) (result bool, error string) { + return !isNil(params[0]), "" +} + +// ----------------------------------------------------------------------- +// Equals checker. + +type equalsChecker struct { + *CheckerInfo +} + +// The Equals checker verifies that the obtained value is equal to +// the expected value, according to usual Go semantics for ==. +// +// For example: +// +// c.Assert(value, Equals, 42) +// +var Equals Checker = &equalsChecker{ + &CheckerInfo{Name: "Equals", Params: []string{"obtained", "expected"}}, +} + +func (checker *equalsChecker) Check(params []interface{}, names []string) (result bool, error string) { + defer func() { + if v := recover(); v != nil { + result = false + error = fmt.Sprint(v) + } + }() + return params[0] == params[1], "" +} + +// ----------------------------------------------------------------------- +// DeepEquals checker. + +type deepEqualsChecker struct { + *CheckerInfo +} + +// The DeepEquals checker verifies that the obtained value is deep-equal to +// the expected value. The check will work correctly even when facing +// slices, interfaces, and values of different types (which always fail +// the test). +// +// For example: +// +// c.Assert(value, DeepEquals, 42) +// c.Assert(array, DeepEquals, []string{"hi", "there"}) +// +var DeepEquals Checker = &deepEqualsChecker{ + &CheckerInfo{Name: "DeepEquals", Params: []string{"obtained", "expected"}}, +} + +func (checker *deepEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) { + return reflect.DeepEqual(params[0], params[1]), "" +} + +// ----------------------------------------------------------------------- +// HasLen checker. + +type hasLenChecker struct { + *CheckerInfo +} + +// The HasLen checker verifies that the obtained value has the +// provided length. In many cases this is superior to using Equals +// in conjuction with the len function because in case the check +// fails the value itself will be printed, instead of its length, +// providing more details for figuring the problem. +// +// For example: +// +// c.Assert(list, HasLen, 5) +// +var HasLen Checker = &hasLenChecker{ + &CheckerInfo{Name: "HasLen", Params: []string{"obtained", "n"}}, +} + +func (checker *hasLenChecker) Check(params []interface{}, names []string) (result bool, error string) { + n, ok := params[1].(int) + if !ok { + return false, "n must be an int" + } + value := reflect.ValueOf(params[0]) + switch value.Kind() { + case reflect.Map, reflect.Array, reflect.Slice, reflect.Chan, reflect.String: + default: + return false, "obtained value type has no length" + } + return value.Len() == n, "" +} + +// ----------------------------------------------------------------------- +// ErrorMatches checker. + +type errorMatchesChecker struct { + *CheckerInfo +} + +// The ErrorMatches checker verifies that the error value +// is non nil and matches the regular expression provided. +// +// For example: +// +// c.Assert(err, ErrorMatches, "perm.*denied") +// +var ErrorMatches Checker = errorMatchesChecker{ + &CheckerInfo{Name: "ErrorMatches", Params: []string{"value", "regex"}}, +} + +func (checker errorMatchesChecker) Check(params []interface{}, names []string) (result bool, errStr string) { + if params[0] == nil { + return false, "Error value is nil" + } + err, ok := params[0].(error) + if !ok { + return false, "Value is not an error" + } + params[0] = err.Error() + names[0] = "error" + return matches(params[0], params[1]) +} + +// ----------------------------------------------------------------------- +// Matches checker. + +type matchesChecker struct { + *CheckerInfo +} + +// The Matches checker verifies that the string provided as the obtained +// value (or the string resulting from obtained.String()) matches the +// regular expression provided. +// +// For example: +// +// c.Assert(err, Matches, "perm.*denied") +// +var Matches Checker = &matchesChecker{ + &CheckerInfo{Name: "Matches", Params: []string{"value", "regex"}}, +} + +func (checker *matchesChecker) Check(params []interface{}, names []string) (result bool, error string) { + return matches(params[0], params[1]) +} + +func matches(value, regex interface{}) (result bool, error string) { + reStr, ok := regex.(string) + if !ok { + return false, "Regex must be a string" + } + valueStr, valueIsStr := value.(string) + if !valueIsStr { + if valueWithStr, valueHasStr := value.(fmt.Stringer); valueHasStr { + valueStr, valueIsStr = valueWithStr.String(), true + } + } + if valueIsStr { + matches, err := regexp.MatchString("^"+reStr+"$", valueStr) + if err != nil { + return false, "Can't compile regex: " + err.Error() + } + return matches, "" + } + return false, "Obtained value is not a string and has no .String()" +} + +// ----------------------------------------------------------------------- +// Panics checker. + +type panicsChecker struct { + *CheckerInfo +} + +// The Panics checker verifies that calling the provided zero-argument +// function will cause a panic which is deep-equal to the provided value. +// +// For example: +// +// c.Assert(func() { f(1, 2) }, Panics, &SomeErrorType{"BOOM"}). +// +// +var Panics Checker = &panicsChecker{ + &CheckerInfo{Name: "Panics", Params: []string{"function", "expected"}}, +} + +func (checker *panicsChecker) Check(params []interface{}, names []string) (result bool, error string) { + f := reflect.ValueOf(params[0]) + if f.Kind() != reflect.Func || f.Type().NumIn() != 0 { + return false, "Function must take zero arguments" + } + defer func() { + // If the function has not panicked, then don't do the check. + if error != "" { + return + } + params[0] = recover() + names[0] = "panic" + result = reflect.DeepEqual(params[0], params[1]) + }() + f.Call(nil) + return false, "Function has not panicked" +} + +type panicMatchesChecker struct { + *CheckerInfo +} + +// The PanicMatches checker verifies that calling the provided zero-argument +// function will cause a panic with an error value matching +// the regular expression provided. +// +// For example: +// +// c.Assert(func() { f(1, 2) }, PanicMatches, `open.*: no such file or directory`). +// +// +var PanicMatches Checker = &panicMatchesChecker{ + &CheckerInfo{Name: "PanicMatches", Params: []string{"function", "expected"}}, +} + +func (checker *panicMatchesChecker) Check(params []interface{}, names []string) (result bool, errmsg string) { + f := reflect.ValueOf(params[0]) + if f.Kind() != reflect.Func || f.Type().NumIn() != 0 { + return false, "Function must take zero arguments" + } + defer func() { + // If the function has not panicked, then don't do the check. + if errmsg != "" { + return + } + obtained := recover() + names[0] = "panic" + if e, ok := obtained.(error); ok { + params[0] = e.Error() + } else if _, ok := obtained.(string); ok { + params[0] = obtained + } else { + errmsg = "Panic value is not a string or an error" + return + } + result, errmsg = matches(params[0], params[1]) + }() + f.Call(nil) + return false, "Function has not panicked" +} + +// ----------------------------------------------------------------------- +// FitsTypeOf checker. + +type fitsTypeChecker struct { + *CheckerInfo +} + +// The FitsTypeOf checker verifies that the obtained value is +// assignable to a variable with the same type as the provided +// sample value. +// +// For example: +// +// c.Assert(value, FitsTypeOf, int64(0)) +// c.Assert(value, FitsTypeOf, os.Error(nil)) +// +var FitsTypeOf Checker = &fitsTypeChecker{ + &CheckerInfo{Name: "FitsTypeOf", Params: []string{"obtained", "sample"}}, +} + +func (checker *fitsTypeChecker) Check(params []interface{}, names []string) (result bool, error string) { + obtained := reflect.ValueOf(params[0]) + sample := reflect.ValueOf(params[1]) + if !obtained.IsValid() { + return false, "" + } + if !sample.IsValid() { + return false, "Invalid sample value" + } + return obtained.Type().AssignableTo(sample.Type()), "" +} + +// ----------------------------------------------------------------------- +// Implements checker. + +type implementsChecker struct { + *CheckerInfo +} + +// The Implements checker verifies that the obtained value +// implements the interface specified via a pointer to an interface +// variable. +// +// For example: +// +// var e os.Error +// c.Assert(err, Implements, &e) +// +var Implements Checker = &implementsChecker{ + &CheckerInfo{Name: "Implements", Params: []string{"obtained", "ifaceptr"}}, +} + +func (checker *implementsChecker) Check(params []interface{}, names []string) (result bool, error string) { + obtained := reflect.ValueOf(params[0]) + ifaceptr := reflect.ValueOf(params[1]) + if !obtained.IsValid() { + return false, "" + } + if !ifaceptr.IsValid() || ifaceptr.Kind() != reflect.Ptr || ifaceptr.Elem().Kind() != reflect.Interface { + return false, "ifaceptr should be a pointer to an interface variable" + } + return obtained.Type().Implements(ifaceptr.Elem().Type()), "" +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers2.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers2.go new file mode 100644 index 0000000..c09bcdc --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/checkers2.go @@ -0,0 +1,131 @@ +// Extensions to the go-check unittest framework. +// +// NOTE: see https://github.com/go-check/check/pull/6 for reasons why these +// checkers live here. +package check + +import ( + "bytes" + "reflect" +) + +// ----------------------------------------------------------------------- +// IsTrue / IsFalse checker. + +type isBoolValueChecker struct { + *CheckerInfo + expected bool +} + +func (checker *isBoolValueChecker) Check( + params []interface{}, + names []string) ( + result bool, + error string) { + + obtained, ok := params[0].(bool) + if !ok { + return false, "Argument to " + checker.Name + " must be bool" + } + + return obtained == checker.expected, "" +} + +// The IsTrue checker verifies that the obtained value is true. +// +// For example: +// +// c.Assert(value, IsTrue) +// +var IsTrue Checker = &isBoolValueChecker{ + &CheckerInfo{Name: "IsTrue", Params: []string{"obtained"}}, + true, +} + +// The IsFalse checker verifies that the obtained value is false. +// +// For example: +// +// c.Assert(value, IsFalse) +// +var IsFalse Checker = &isBoolValueChecker{ + &CheckerInfo{Name: "IsFalse", Params: []string{"obtained"}}, + false, +} + +// ----------------------------------------------------------------------- +// BytesEquals checker. + +type bytesEquals struct{} + +func (b *bytesEquals) Check(params []interface{}, names []string) (bool, string) { + if len(params) != 2 { + return false, "BytesEqual takes 2 bytestring arguments" + } + b1, ok1 := params[0].([]byte) + b2, ok2 := params[1].([]byte) + + if !(ok1 && ok2) { + return false, "Arguments to BytesEqual must both be bytestrings" + } + + return bytes.Equal(b1, b2), "" +} + +func (b *bytesEquals) Info() *CheckerInfo { + return &CheckerInfo{ + Name: "BytesEquals", + Params: []string{"bytes_one", "bytes_two"}, + } +} + +// BytesEquals checker compares two bytes sequence using bytes.Equal. +// +// For example: +// +// c.Assert(b, BytesEquals, []byte("bar")) +// +// Main difference between DeepEquals and BytesEquals is that BytesEquals treats +// `nil` as empty byte sequence while DeepEquals doesn't. +// +// c.Assert(nil, BytesEquals, []byte("")) // succeeds +// c.Assert(nil, DeepEquals, []byte("")) // fails +var BytesEquals = &bytesEquals{} + +// ----------------------------------------------------------------------- +// HasKey checker. + +type hasKey struct{} + +func (h *hasKey) Check(params []interface{}, names []string) (bool, string) { + if len(params) != 2 { + return false, "HasKey takes 2 arguments: a map and a key" + } + + mapValue := reflect.ValueOf(params[0]) + if mapValue.Kind() != reflect.Map { + return false, "First argument to HasKey must be a map" + } + + keyValue := reflect.ValueOf(params[1]) + if !keyValue.Type().AssignableTo(mapValue.Type().Key()) { + return false, "Second argument must be assignable to the map key type" + } + + return mapValue.MapIndex(keyValue).IsValid(), "" +} + +func (h *hasKey) Info() *CheckerInfo { + return &CheckerInfo{ + Name: "HasKey", + Params: []string{"obtained", "key"}, + } +} + +// The HasKey checker verifies that the obtained map contains the given key. +// +// For example: +// +// c.Assert(myMap, HasKey, "foo") +// +var HasKey = &hasKey{} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/compare.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/compare.go new file mode 100644 index 0000000..7005cba --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/compare.go @@ -0,0 +1,161 @@ +package check + +import ( + "bytes" + "fmt" + "reflect" + "time" +) + +type compareFunc func(v1 interface{}, v2 interface{}) (bool, error) + +type valueCompare struct { + Name string + + Func compareFunc + + Operator string +} + +// v1 and v2 must have the same type +// return >0 if v1 > v2 +// return 0 if v1 = v2 +// return <0 if v1 < v2 +// now we only support int, uint, float64, string and []byte comparison +func compare(v1 interface{}, v2 interface{}) (int, error) { + value1 := reflect.ValueOf(v1) + value2 := reflect.ValueOf(v2) + + switch v1.(type) { + case int, int8, int16, int32, int64: + a1 := value1.Int() + a2 := value2.Int() + if a1 > a2 { + return 1, nil + } else if a1 == a2 { + return 0, nil + } + return -1, nil + case uint, uint8, uint16, uint32, uint64: + a1 := value1.Uint() + a2 := value2.Uint() + if a1 > a2 { + return 1, nil + } else if a1 == a2 { + return 0, nil + } + return -1, nil + case float32, float64: + a1 := value1.Float() + a2 := value2.Float() + if a1 > a2 { + return 1, nil + } else if a1 == a2 { + return 0, nil + } + return -1, nil + case string: + a1 := value1.String() + a2 := value2.String() + if a1 > a2 { + return 1, nil + } else if a1 == a2 { + return 0, nil + } + return -1, nil + case []byte: + a1 := value1.Bytes() + a2 := value2.Bytes() + return bytes.Compare(a1, a2), nil + case time.Time: + a1 := v1.(time.Time) + a2 := v2.(time.Time) + if a1.After(a2) { + return 1, nil + } else if a1.Equal(a2) { + return 0, nil + } + return -1, nil + case time.Duration: + a1 := v1.(time.Duration) + a2 := v2.(time.Duration) + if a1 > a2 { + return 1, nil + } else if a1 == a2 { + return 0, nil + } + return -1, nil + default: + return 0, fmt.Errorf("type %T is not supported now", v1) + } +} + +func less(v1 interface{}, v2 interface{}) (bool, error) { + n, err := compare(v1, v2) + if err != nil { + return false, err + } + + return n < 0, nil +} + +func lessEqual(v1 interface{}, v2 interface{}) (bool, error) { + n, err := compare(v1, v2) + if err != nil { + return false, err + } + + return n <= 0, nil +} + +func greater(v1 interface{}, v2 interface{}) (bool, error) { + n, err := compare(v1, v2) + if err != nil { + return false, err + } + + return n > 0, nil +} + +func greaterEqual(v1 interface{}, v2 interface{}) (bool, error) { + n, err := compare(v1, v2) + if err != nil { + return false, err + } + + return n >= 0, nil +} + +func (v *valueCompare) Check(params []interface{}, names []string) (bool, string) { + if len(params) != 2 { + return false, fmt.Sprintf("%s needs 2 arguments", v.Name) + } + + v1 := params[0] + v2 := params[1] + v1Type := reflect.TypeOf(v1) + v2Type := reflect.TypeOf(v2) + + if v1Type.Kind() != v2Type.Kind() { + return false, fmt.Sprintf("%s needs two same type, but %s != %s", v.Name, v1Type.Kind(), v2Type.Kind()) + } + + b, err := v.Func(v1, v2) + if err != nil { + return false, fmt.Sprintf("%s check err %v", v.Name, err) + } + + return b, "" +} + +func (v *valueCompare) Info() *CheckerInfo { + return &CheckerInfo{ + Name: v.Name, + Params: []string{"compare_one", "compare_two"}, + } +} + +var Less = &valueCompare{Name: "Less", Func: less, Operator: "<"} +var LessEqual = &valueCompare{Name: "LessEqual", Func: lessEqual, Operator: "<="} +var Greater = &valueCompare{Name: "Greater", Func: greater, Operator: ">"} +var GreaterEqual = &valueCompare{Name: "GreaterEqual", Func: greaterEqual, Operator: ">="} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/helpers.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/helpers.go new file mode 100644 index 0000000..58a733b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/helpers.go @@ -0,0 +1,231 @@ +package check + +import ( + "fmt" + "strings" + "time" +) + +// TestName returns the current test name in the form "SuiteName.TestName" +func (c *C) TestName() string { + return c.testName +} + +// ----------------------------------------------------------------------- +// Basic succeeding/failing logic. + +// Failed returns whether the currently running test has already failed. +func (c *C) Failed() bool { + return c.status() == failedSt +} + +// Fail marks the currently running test as failed. +// +// Something ought to have been previously logged so the developer can tell +// what went wrong. The higher level helper functions will fail the test +// and do the logging properly. +func (c *C) Fail() { + c.setStatus(failedSt) +} + +// FailNow marks the currently running test as failed and stops running it. +// Something ought to have been previously logged so the developer can tell +// what went wrong. The higher level helper functions will fail the test +// and do the logging properly. +func (c *C) FailNow() { + c.Fail() + c.stopNow() +} + +// Succeed marks the currently running test as succeeded, undoing any +// previous failures. +func (c *C) Succeed() { + c.setStatus(succeededSt) +} + +// SucceedNow marks the currently running test as succeeded, undoing any +// previous failures, and stops running the test. +func (c *C) SucceedNow() { + c.Succeed() + c.stopNow() +} + +// ExpectFailure informs that the running test is knowingly broken for +// the provided reason. If the test does not fail, an error will be reported +// to raise attention to this fact. This method is useful to temporarily +// disable tests which cover well known problems until a better time to +// fix the problem is found, without forgetting about the fact that a +// failure still exists. +func (c *C) ExpectFailure(reason string) { + if reason == "" { + panic("Missing reason why the test is expected to fail") + } + c.mustFail = true + c.reason = reason +} + +// Skip skips the running test for the provided reason. If run from within +// SetUpTest, the individual test being set up will be skipped, and if run +// from within SetUpSuite, the whole suite is skipped. +func (c *C) Skip(reason string) { + if reason == "" { + panic("Missing reason why the test is being skipped") + } + c.reason = reason + c.setStatus(skippedSt) + c.stopNow() +} + +// ----------------------------------------------------------------------- +// Basic logging. + +// GetTestLog returns the current test error output. +func (c *C) GetTestLog() string { + return c.logb.String() +} + +// Log logs some information into the test error output. +// The provided arguments are assembled together into a string with fmt.Sprint. +func (c *C) Log(args ...interface{}) { + c.log(args...) +} + +// Log logs some information into the test error output. +// The provided arguments are assembled together into a string with fmt.Sprintf. +func (c *C) Logf(format string, args ...interface{}) { + c.logf(format, args...) +} + +// Output enables *C to be used as a logger in functions that require only +// the minimum interface of *log.Logger. +func (c *C) Output(calldepth int, s string) error { + d := time.Now().Sub(c.startTime) + msec := d / time.Millisecond + sec := d / time.Second + min := d / time.Minute + + c.Logf("[LOG] %d:%02d.%03d %s", min, sec%60, msec%1000, s) + return nil +} + +// Error logs an error into the test error output and marks the test as failed. +// The provided arguments are assembled together into a string with fmt.Sprint. +func (c *C) Error(args ...interface{}) { + c.logCaller(1) + c.logString(fmt.Sprint("Error: ", fmt.Sprint(args...))) + c.logNewLine() + c.Fail() +} + +// Errorf logs an error into the test error output and marks the test as failed. +// The provided arguments are assembled together into a string with fmt.Sprintf. +func (c *C) Errorf(format string, args ...interface{}) { + c.logCaller(1) + c.logString(fmt.Sprintf("Error: "+format, args...)) + c.logNewLine() + c.Fail() +} + +// Fatal logs an error into the test error output, marks the test as failed, and +// stops the test execution. The provided arguments are assembled together into +// a string with fmt.Sprint. +func (c *C) Fatal(args ...interface{}) { + c.logCaller(1) + c.logString(fmt.Sprint("Error: ", fmt.Sprint(args...))) + c.logNewLine() + c.FailNow() +} + +// Fatlaf logs an error into the test error output, marks the test as failed, and +// stops the test execution. The provided arguments are assembled together into +// a string with fmt.Sprintf. +func (c *C) Fatalf(format string, args ...interface{}) { + c.logCaller(1) + c.logString(fmt.Sprint("Error: ", fmt.Sprintf(format, args...))) + c.logNewLine() + c.FailNow() +} + +// ----------------------------------------------------------------------- +// Generic checks and assertions based on checkers. + +// Check verifies if the first value matches the expected value according +// to the provided checker. If they do not match, an error is logged, the +// test is marked as failed, and the test execution continues. +// +// Some checkers may not need the expected argument (e.g. IsNil). +// +// Extra arguments provided to the function are logged next to the reported +// problem when the matching fails. +func (c *C) Check(obtained interface{}, checker Checker, args ...interface{}) bool { + return c.internalCheck("Check", obtained, checker, args...) +} + +// Assert ensures that the first value matches the expected value according +// to the provided checker. If they do not match, an error is logged, the +// test is marked as failed, and the test execution stops. +// +// Some checkers may not need the expected argument (e.g. IsNil). +// +// Extra arguments provided to the function are logged next to the reported +// problem when the matching fails. +func (c *C) Assert(obtained interface{}, checker Checker, args ...interface{}) { + if !c.internalCheck("Assert", obtained, checker, args...) { + c.stopNow() + } +} + +func (c *C) internalCheck(funcName string, obtained interface{}, checker Checker, args ...interface{}) bool { + if checker == nil { + c.logCaller(2) + c.logString(fmt.Sprintf("%s(obtained, nil!?, ...):", funcName)) + c.logString("Oops.. you've provided a nil checker!") + c.logNewLine() + c.Fail() + return false + } + + // If the last argument is a bug info, extract it out. + var comment CommentInterface + if len(args) > 0 { + if c, ok := args[len(args)-1].(CommentInterface); ok { + comment = c + args = args[:len(args)-1] + } + } + + params := append([]interface{}{obtained}, args...) + info := checker.Info() + + if len(params) != len(info.Params) { + names := append([]string{info.Params[0], info.Name}, info.Params[1:]...) + c.logCaller(2) + c.logString(fmt.Sprintf("%s(%s):", funcName, strings.Join(names, ", "))) + c.logString(fmt.Sprintf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(names), len(params)+1)) + c.logNewLine() + c.Fail() + return false + } + + // Copy since it may be mutated by Check. + names := append([]string{}, info.Params...) + + // Do the actual check. + result, error := checker.Check(params, names) + if !result || error != "" { + c.logCaller(2) + for i := 0; i != len(params); i++ { + c.logValue(names[i], params[i]) + } + if comment != nil { + c.logString(comment.CheckCommentString()) + } + if error != "" { + c.logString(error) + } + c.logNewLine() + c.Fail() + return false + } + return true +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/printer.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/printer.go new file mode 100644 index 0000000..e0f7557 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/printer.go @@ -0,0 +1,168 @@ +package check + +import ( + "bytes" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "os" +) + +func indent(s, with string) (r string) { + eol := true + for i := 0; i != len(s); i++ { + c := s[i] + switch { + case eol && c == '\n' || c == '\r': + case c == '\n' || c == '\r': + eol = true + case eol: + eol = false + s = s[:i] + with + s[i:] + i += len(with) + } + } + return s +} + +func printLine(filename string, line int) (string, error) { + fset := token.NewFileSet() + file, err := os.Open(filename) + if err != nil { + return "", err + } + fnode, err := parser.ParseFile(fset, filename, file, parser.ParseComments) + if err != nil { + return "", err + } + config := &printer.Config{Mode: printer.UseSpaces, Tabwidth: 4} + lp := &linePrinter{fset: fset, fnode: fnode, line: line, config: config} + ast.Walk(lp, fnode) + result := lp.output.Bytes() + // Comments leave \n at the end. + n := len(result) + for n > 0 && result[n-1] == '\n' { + n-- + } + return string(result[:n]), nil +} + +type linePrinter struct { + config *printer.Config + fset *token.FileSet + fnode *ast.File + line int + output bytes.Buffer + stmt ast.Stmt +} + +func (lp *linePrinter) emit() bool { + if lp.stmt != nil { + lp.trim(lp.stmt) + lp.printWithComments(lp.stmt) + lp.stmt = nil + return true + } + return false +} + +func (lp *linePrinter) printWithComments(n ast.Node) { + nfirst := lp.fset.Position(n.Pos()).Line + nlast := lp.fset.Position(n.End()).Line + for _, g := range lp.fnode.Comments { + cfirst := lp.fset.Position(g.Pos()).Line + clast := lp.fset.Position(g.End()).Line + if clast == nfirst-1 && lp.fset.Position(n.Pos()).Column == lp.fset.Position(g.Pos()).Column { + for _, c := range g.List { + lp.output.WriteString(c.Text) + lp.output.WriteByte('\n') + } + } + if cfirst >= nfirst && cfirst <= nlast && n.End() <= g.List[0].Slash { + // The printer will not include the comment if it starts past + // the node itself. Trick it into printing by overlapping the + // slash with the end of the statement. + g.List[0].Slash = n.End() - 1 + } + } + node := &printer.CommentedNode{n, lp.fnode.Comments} + lp.config.Fprint(&lp.output, lp.fset, node) +} + +func (lp *linePrinter) Visit(n ast.Node) (w ast.Visitor) { + if n == nil { + if lp.output.Len() == 0 { + lp.emit() + } + return nil + } + first := lp.fset.Position(n.Pos()).Line + last := lp.fset.Position(n.End()).Line + if first <= lp.line && last >= lp.line { + // Print the innermost statement containing the line. + if stmt, ok := n.(ast.Stmt); ok { + if _, ok := n.(*ast.BlockStmt); !ok { + lp.stmt = stmt + } + } + if first == lp.line && lp.emit() { + return nil + } + return lp + } + return nil +} + +func (lp *linePrinter) trim(n ast.Node) bool { + stmt, ok := n.(ast.Stmt) + if !ok { + return true + } + line := lp.fset.Position(n.Pos()).Line + if line != lp.line { + return false + } + switch stmt := stmt.(type) { + case *ast.IfStmt: + stmt.Body = lp.trimBlock(stmt.Body) + case *ast.SwitchStmt: + stmt.Body = lp.trimBlock(stmt.Body) + case *ast.TypeSwitchStmt: + stmt.Body = lp.trimBlock(stmt.Body) + case *ast.CaseClause: + stmt.Body = lp.trimList(stmt.Body) + case *ast.CommClause: + stmt.Body = lp.trimList(stmt.Body) + case *ast.BlockStmt: + stmt.List = lp.trimList(stmt.List) + } + return true +} + +func (lp *linePrinter) trimBlock(stmt *ast.BlockStmt) *ast.BlockStmt { + if !lp.trim(stmt) { + return lp.emptyBlock(stmt) + } + stmt.Rbrace = stmt.Lbrace + return stmt +} + +func (lp *linePrinter) trimList(stmts []ast.Stmt) []ast.Stmt { + for i := 0; i != len(stmts); i++ { + if !lp.trim(stmts[i]) { + stmts[i] = lp.emptyStmt(stmts[i]) + break + } + } + return stmts +} + +func (lp *linePrinter) emptyStmt(n ast.Node) *ast.ExprStmt { + return &ast.ExprStmt{&ast.Ellipsis{n.Pos(), nil}} +} + +func (lp *linePrinter) emptyBlock(n ast.Node) *ast.BlockStmt { + p := n.Pos() + return &ast.BlockStmt{p, []ast.Stmt{lp.emptyStmt(n)}, p} +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/run.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/run.go new file mode 100644 index 0000000..da8fd79 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/pingcap/check/run.go @@ -0,0 +1,175 @@ +package check + +import ( + "bufio" + "flag" + "fmt" + "os" + "testing" + "time" +) + +// ----------------------------------------------------------------------- +// Test suite registry. + +var allSuites []interface{} + +// Suite registers the given value as a test suite to be run. Any methods +// starting with the Test prefix in the given value will be considered as +// a test method. +func Suite(suite interface{}) interface{} { + allSuites = append(allSuites, suite) + return suite +} + +// ----------------------------------------------------------------------- +// Public running interface. + +var ( + oldFilterFlag = flag.String("gocheck.f", "", "Regular expression selecting which tests and/or suites to run") + oldVerboseFlag = flag.Bool("gocheck.v", false, "Verbose mode") + oldStreamFlag = flag.Bool("gocheck.vv", false, "Super verbose mode (disables output caching)") + oldBenchFlag = flag.Bool("gocheck.b", false, "Run benchmarks") + oldBenchTime = flag.Duration("gocheck.btime", 1*time.Second, "approximate run time for each benchmark") + oldListFlag = flag.Bool("gocheck.list", false, "List the names of all tests that will be run") + oldWorkFlag = flag.Bool("gocheck.work", false, "Display and do not remove the test working directory") + + newFilterFlag = flag.String("check.f", "", "Regular expression selecting which tests and/or suites to run") + newVerboseFlag = flag.Bool("check.v", false, "Verbose mode") + newStreamFlag = flag.Bool("check.vv", false, "Super verbose mode (disables output caching)") + newBenchFlag = flag.Bool("check.b", false, "Run benchmarks") + newBenchTime = flag.Duration("check.btime", 1*time.Second, "approximate run time for each benchmark") + newBenchMem = flag.Bool("check.bmem", false, "Report memory benchmarks") + newListFlag = flag.Bool("check.list", false, "List the names of all tests that will be run") + newWorkFlag = flag.Bool("check.work", false, "Display and do not remove the test working directory") +) + +// TestingT runs all test suites registered with the Suite function, +// printing results to stdout, and reporting any failures back to +// the "testing" package. +func TestingT(testingT *testing.T) { + benchTime := *newBenchTime + if benchTime == 1*time.Second { + benchTime = *oldBenchTime + } + conf := &RunConf{ + Filter: *oldFilterFlag + *newFilterFlag, + Verbose: *oldVerboseFlag || *newVerboseFlag, + Stream: *oldStreamFlag || *newStreamFlag, + Benchmark: *oldBenchFlag || *newBenchFlag, + BenchmarkTime: benchTime, + BenchmarkMem: *newBenchMem, + KeepWorkDir: *oldWorkFlag || *newWorkFlag, + } + if *oldListFlag || *newListFlag { + w := bufio.NewWriter(os.Stdout) + for _, name := range ListAll(conf) { + fmt.Fprintln(w, name) + } + w.Flush() + return + } + result := RunAll(conf) + println(result.String()) + if !result.Passed() { + testingT.Fail() + } +} + +// RunAll runs all test suites registered with the Suite function, using the +// provided run configuration. +func RunAll(runConf *RunConf) *Result { + result := Result{} + for _, suite := range allSuites { + result.Add(Run(suite, runConf)) + } + return &result +} + +// Run runs the provided test suite using the provided run configuration. +func Run(suite interface{}, runConf *RunConf) *Result { + runner := newSuiteRunner(suite, runConf) + return runner.run() +} + +// ListAll returns the names of all the test functions registered with the +// Suite function that will be run with the provided run configuration. +func ListAll(runConf *RunConf) []string { + var names []string + for _, suite := range allSuites { + names = append(names, List(suite, runConf)...) + } + return names +} + +// List returns the names of the test functions in the given +// suite that will be run with the provided run configuration. +func List(suite interface{}, runConf *RunConf) []string { + var names []string + runner := newSuiteRunner(suite, runConf) + for _, t := range runner.tests { + names = append(names, t.String()) + } + return names +} + +// ----------------------------------------------------------------------- +// Result methods. + +func (r *Result) Add(other *Result) { + r.Succeeded += other.Succeeded + r.Skipped += other.Skipped + r.Failed += other.Failed + r.Panicked += other.Panicked + r.FixturePanicked += other.FixturePanicked + r.ExpectedFailures += other.ExpectedFailures + r.Missed += other.Missed + if r.WorkDir != "" && other.WorkDir != "" { + r.WorkDir += ":" + other.WorkDir + } else if other.WorkDir != "" { + r.WorkDir = other.WorkDir + } +} + +func (r *Result) Passed() bool { + return (r.Failed == 0 && r.Panicked == 0 && + r.FixturePanicked == 0 && r.Missed == 0 && + r.RunError == nil) +} + +func (r *Result) String() string { + if r.RunError != nil { + return "ERROR: " + r.RunError.Error() + } + + var value string + if r.Failed == 0 && r.Panicked == 0 && r.FixturePanicked == 0 && + r.Missed == 0 { + value = "OK: " + } else { + value = "OOPS: " + } + value += fmt.Sprintf("%d passed", r.Succeeded) + if r.Skipped != 0 { + value += fmt.Sprintf(", %d skipped", r.Skipped) + } + if r.ExpectedFailures != 0 { + value += fmt.Sprintf(", %d expected failures", r.ExpectedFailures) + } + if r.Failed != 0 { + value += fmt.Sprintf(", %d FAILED", r.Failed) + } + if r.Panicked != 0 { + value += fmt.Sprintf(", %d PANICKED", r.Panicked) + } + if r.FixturePanicked != 0 { + value += fmt.Sprintf(", %d FIXTURE-PANICKED", r.FixturePanicked) + } + if r.Missed != 0 { + value += fmt.Sprintf(", %d MISSED", r.Missed) + } + if r.WorkDir != "" { + value += "\nWORK=" + r.WorkDir + } + return value +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/LICENSE new file mode 100644 index 0000000..488357b --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/LICENSE @@ -0,0 +1,20 @@ +Copyright (C) 2013-2016 by Maxim Bublis + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/uuid.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/uuid.go new file mode 100644 index 0000000..9c7fbaa --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/satori/go.uuid/uuid.go @@ -0,0 +1,488 @@ +// Copyright (C) 2013-2015 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Package uuid provides implementation of Universally Unique Identifier (UUID). +// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and +// version 2 (as specified in DCE 1.1). +package uuid + +import ( + "bytes" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "database/sql/driver" + "encoding/binary" + "encoding/hex" + "fmt" + "hash" + "net" + "os" + "sync" + "time" +) + +// UUID layout variants. +const ( + VariantNCS = iota + VariantRFC4122 + VariantMicrosoft + VariantFuture +) + +// UUID DCE domains. +const ( + DomainPerson = iota + DomainGroup + DomainOrg +) + +// Difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970). +const epochStart = 122192928000000000 + +// Used in string method conversion +const dash byte = '-' + +// UUID v1/v2 storage. +var ( + storageMutex sync.Mutex + storageOnce sync.Once + epochFunc = unixTimeFunc + clockSequence uint16 + lastTime uint64 + hardwareAddr [6]byte + posixUID = uint32(os.Getuid()) + posixGID = uint32(os.Getgid()) +) + +// String parse helpers. +var ( + urnPrefix = []byte("urn:uuid:") + byteGroups = []int{8, 4, 4, 4, 12} +) + +func initClockSequence() { + buf := make([]byte, 2) + safeRandom(buf) + clockSequence = binary.BigEndian.Uint16(buf) +} + +func initHardwareAddr() { + interfaces, err := net.Interfaces() + if err == nil { + for _, iface := range interfaces { + if len(iface.HardwareAddr) >= 6 { + copy(hardwareAddr[:], iface.HardwareAddr) + return + } + } + } + + // Initialize hardwareAddr randomly in case + // of real network interfaces absence + safeRandom(hardwareAddr[:]) + + // Set multicast bit as recommended in RFC 4122 + hardwareAddr[0] |= 0x01 +} + +func initStorage() { + initClockSequence() + initHardwareAddr() +} + +func safeRandom(dest []byte) { + if _, err := rand.Read(dest); err != nil { + panic(err) + } +} + +// Returns difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and current time. +// This is default epoch calculation function. +func unixTimeFunc() uint64 { + return epochStart + uint64(time.Now().UnixNano()/100) +} + +// UUID representation compliant with specification +// described in RFC 4122. +type UUID [16]byte + +// NullUUID can be used with the standard sql package to represent a +// UUID value that can be NULL in the database +type NullUUID struct { + UUID UUID + Valid bool +} + +// The nil UUID is special form of UUID that is specified to have all +// 128 bits set to zero. +var Nil = UUID{} + +// Predefined namespace UUIDs. +var ( + NamespaceDNS, _ = FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + NamespaceURL, _ = FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8") + NamespaceOID, _ = FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8") + NamespaceX500, _ = FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8") +) + +// And returns result of binary AND of two UUIDs. +func And(u1 UUID, u2 UUID) UUID { + u := UUID{} + for i := 0; i < 16; i++ { + u[i] = u1[i] & u2[i] + } + return u +} + +// Or returns result of binary OR of two UUIDs. +func Or(u1 UUID, u2 UUID) UUID { + u := UUID{} + for i := 0; i < 16; i++ { + u[i] = u1[i] | u2[i] + } + return u +} + +// Equal returns true if u1 and u2 equals, otherwise returns false. +func Equal(u1 UUID, u2 UUID) bool { + return bytes.Equal(u1[:], u2[:]) +} + +// Version returns algorithm version used to generate UUID. +func (u UUID) Version() uint { + return uint(u[6] >> 4) +} + +// Variant returns UUID layout variant. +func (u UUID) Variant() uint { + switch { + case (u[8] & 0x80) == 0x00: + return VariantNCS + case (u[8]&0xc0)|0x80 == 0x80: + return VariantRFC4122 + case (u[8]&0xe0)|0xc0 == 0xc0: + return VariantMicrosoft + } + return VariantFuture +} + +// Bytes returns bytes slice representation of UUID. +func (u UUID) Bytes() []byte { + return u[:] +} + +// Returns canonical string representation of UUID: +// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. +func (u UUID) String() string { + buf := make([]byte, 36) + + hex.Encode(buf[0:8], u[0:4]) + buf[8] = dash + hex.Encode(buf[9:13], u[4:6]) + buf[13] = dash + hex.Encode(buf[14:18], u[6:8]) + buf[18] = dash + hex.Encode(buf[19:23], u[8:10]) + buf[23] = dash + hex.Encode(buf[24:], u[10:]) + + return string(buf) +} + +// SetVersion sets version bits. +func (u *UUID) SetVersion(v byte) { + u[6] = (u[6] & 0x0f) | (v << 4) +} + +// SetVariant sets variant bits as described in RFC 4122. +func (u *UUID) SetVariant() { + u[8] = (u[8] & 0xbf) | 0x80 +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The encoding is the same as returned by String. +func (u UUID) MarshalText() (text []byte, err error) { + text = []byte(u.String()) + return +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// Following formats are supported: +// "6ba7b810-9dad-11d1-80b4-00c04fd430c8", +// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}", +// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" +func (u *UUID) UnmarshalText(text []byte) (err error) { + if len(text) < 32 { + err = fmt.Errorf("uuid: UUID string too short: %s", text) + return + } + + t := text[:] + braced := false + + if bytes.Equal(t[:9], urnPrefix) { + t = t[9:] + } else if t[0] == '{' { + braced = true + t = t[1:] + } + + b := u[:] + + for i, byteGroup := range byteGroups { + if i > 0 && t[0] == '-' { + t = t[1:] + } else if i > 0 && t[0] != '-' { + err = fmt.Errorf("uuid: invalid string format") + return + } + + if i == 2 { + if !bytes.Contains([]byte("012345"), []byte{t[0]}) { + err = fmt.Errorf("uuid: invalid version number: %s", t[0]) + return + } + } + + if len(t) < byteGroup { + err = fmt.Errorf("uuid: UUID string too short: %s", text) + return + } + + if i == 4 && len(t) > byteGroup && + ((braced && t[byteGroup] != '}') || len(t[byteGroup:]) > 1 || !braced) { + err = fmt.Errorf("uuid: UUID string too long: %s", t) + return + } + + _, err = hex.Decode(b[:byteGroup/2], t[:byteGroup]) + + if err != nil { + return + } + + t = t[byteGroup:] + b = b[byteGroup/2:] + } + + return +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (u UUID) MarshalBinary() (data []byte, err error) { + data = u.Bytes() + return +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It will return error if the slice isn't 16 bytes long. +func (u *UUID) UnmarshalBinary(data []byte) (err error) { + if len(data) != 16 { + err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data)) + return + } + copy(u[:], data) + + return +} + +// Value implements the driver.Valuer interface. +func (u UUID) Value() (driver.Value, error) { + return u.String(), nil +} + +// Scan implements the sql.Scanner interface. +// A 16-byte slice is handled by UnmarshalBinary, while +// a longer byte slice or a string is handled by UnmarshalText. +func (u *UUID) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + if len(src) == 16 { + return u.UnmarshalBinary(src) + } + return u.UnmarshalText(src) + + case string: + return u.UnmarshalText([]byte(src)) + } + + return fmt.Errorf("uuid: cannot convert %T to UUID", src) +} + +// Value implements the driver.Valuer interface. +func (u NullUUID) Value() (driver.Value, error) { + if !u.Valid { + return nil, nil + } + // Delegate to UUID Value function + return u.UUID.Value() +} + +// Scan implements the sql.Scanner interface. +func (u *NullUUID) Scan(src interface{}) error { + if src == nil { + u.UUID, u.Valid = Nil, false + return nil + } + + // Delegate to UUID Scan function + u.Valid = true + return u.UUID.Scan(src) +} + +// FromBytes returns UUID converted from raw byte slice input. +// It will return error if the slice isn't 16 bytes long. +func FromBytes(input []byte) (u UUID, err error) { + err = u.UnmarshalBinary(input) + return +} + +// FromBytesOrNil returns UUID converted from raw byte slice input. +// Same behavior as FromBytes, but returns a Nil UUID on error. +func FromBytesOrNil(input []byte) UUID { + uuid, err := FromBytes(input) + if err != nil { + return Nil + } + return uuid +} + +// FromString returns UUID parsed from string input. +// Input is expected in a form accepted by UnmarshalText. +func FromString(input string) (u UUID, err error) { + err = u.UnmarshalText([]byte(input)) + return +} + +// FromStringOrNil returns UUID parsed from string input. +// Same behavior as FromString, but returns a Nil UUID on error. +func FromStringOrNil(input string) UUID { + uuid, err := FromString(input) + if err != nil { + return Nil + } + return uuid +} + +// Returns UUID v1/v2 storage state. +// Returns epoch timestamp, clock sequence, and hardware address. +func getStorage() (uint64, uint16, []byte) { + storageOnce.Do(initStorage) + + storageMutex.Lock() + defer storageMutex.Unlock() + + timeNow := epochFunc() + // Clock changed backwards since last UUID generation. + // Should increase clock sequence. + if timeNow <= lastTime { + clockSequence++ + } + lastTime = timeNow + + return timeNow, clockSequence, hardwareAddr[:] +} + +// NewV1 returns UUID based on current timestamp and MAC address. +func NewV1() UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := getStorage() + + binary.BigEndian.PutUint32(u[0:], uint32(timeNow)) + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + + copy(u[10:], hardwareAddr) + + u.SetVersion(1) + u.SetVariant() + + return u +} + +// NewV2 returns DCE Security UUID based on POSIX UID/GID. +func NewV2(domain byte) UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := getStorage() + + switch domain { + case DomainPerson: + binary.BigEndian.PutUint32(u[0:], posixUID) + case DomainGroup: + binary.BigEndian.PutUint32(u[0:], posixGID) + } + + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + u[9] = domain + + copy(u[10:], hardwareAddr) + + u.SetVersion(2) + u.SetVariant() + + return u +} + +// NewV3 returns UUID based on MD5 hash of namespace UUID and name. +func NewV3(ns UUID, name string) UUID { + u := newFromHash(md5.New(), ns, name) + u.SetVersion(3) + u.SetVariant() + + return u +} + +// NewV4 returns random generated UUID. +func NewV4() UUID { + u := UUID{} + safeRandom(u[:]) + u.SetVersion(4) + u.SetVariant() + + return u +} + +// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name. +func NewV5(ns UUID, name string) UUID { + u := newFromHash(sha1.New(), ns, name) + u.SetVersion(5) + u.SetVariant() + + return u +} + +// Returns UUID based on hashing of namespace UUID and name. +func newFromHash(h hash.Hash, ns UUID, name string) UUID { + u := UUID{} + h.Write(ns[:]) + h.Write([]byte(name)) + copy(u[:], h.Sum(nil)) + + return u +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/LICENSE new file mode 100644 index 0000000..80511a0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2014 siddontang + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/hack/hack.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/hack/hack.go new file mode 100644 index 0000000..74ee83c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/hack/hack.go @@ -0,0 +1,27 @@ +package hack + +import ( + "reflect" + "unsafe" +) + +// no copy to change slice to string +// use your own risk +func String(b []byte) (s string) { + pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pstring.Data = pbytes.Data + pstring.Len = pbytes.Len + return +} + +// no copy to change string to slice +// use your own risk +func Slice(s string) (b []byte) { + pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pbytes.Data = pstring.Data + pbytes.Len = pstring.Len + pbytes.Cap = pstring.Len + return +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/ioutil.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/ioutil.go new file mode 100644 index 0000000..c99c987 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/ioutil.go @@ -0,0 +1,39 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ioutil2 + +import ( + "io" + "io/ioutil" + "os" + "path" +) + +// Write file to temp and atomically move when everything else succeeds. +func WriteFileAtomic(filename string, data []byte, perm os.FileMode) error { + dir, name := path.Split(filename) + f, err := ioutil.TempFile(dir, name) + if err != nil { + return err + } + n, err := f.Write(data) + f.Close() + if err == nil && n < len(data) { + err = io.ErrShortWrite + } else { + err = os.Chmod(f.Name(), perm) + } + if err != nil { + os.Remove(f.Name()) + return err + } + return os.Rename(f.Name(), filename) +} + +// Check file exists or not +func FileExists(name string) bool { + _, err := os.Stat(name) + return !os.IsNotExist(err) +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/sectionwriter.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/sectionwriter.go new file mode 100644 index 0000000..c02ab0d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/ioutil2/sectionwriter.go @@ -0,0 +1,69 @@ +package ioutil2 + +import ( + "errors" + "io" +) + +var ErrExceedLimit = errors.New("write exceed limit") + +func NewSectionWriter(w io.WriterAt, off int64, n int64) *SectionWriter { + return &SectionWriter{w, off, off, off + n} +} + +type SectionWriter struct { + w io.WriterAt + base int64 + off int64 + limit int64 +} + +func (s *SectionWriter) Write(p []byte) (n int, err error) { + if s.off >= s.limit { + return 0, ErrExceedLimit + } + + if max := s.limit - s.off; int64(len(p)) > max { + return 0, ErrExceedLimit + } + + n, err = s.w.WriteAt(p, s.off) + s.off += int64(n) + return +} + +var errWhence = errors.New("Seek: invalid whence") +var errOffset = errors.New("Seek: invalid offset") + +func (s *SectionWriter) Seek(offset int64, whence int) (int64, error) { + switch whence { + default: + return 0, errWhence + case 0: + offset += s.base + case 1: + offset += s.off + case 2: + offset += s.limit + } + if offset < s.base { + return 0, errOffset + } + s.off = offset + return offset - s.base, nil +} + +func (s *SectionWriter) WriteAt(p []byte, off int64) (n int, err error) { + if off < 0 || off >= s.limit-s.base { + return 0, errOffset + } + off += s.base + if max := s.limit - off; int64(len(p)) > max { + return 0, ErrExceedLimit + } + + return s.w.WriteAt(p, off) +} + +// Size returns the size of the section in bytes. +func (s *SectionWriter) Size() int64 { return s.limit - s.base } diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/atomic.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/atomic.go new file mode 100644 index 0000000..424a974 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/atomic.go @@ -0,0 +1,146 @@ +// Copyright 2013, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "sync" + "sync/atomic" + "time" +) + +type AtomicInt32 int32 + +func (i *AtomicInt32) Add(n int32) int32 { + return atomic.AddInt32((*int32)(i), n) +} + +func (i *AtomicInt32) Set(n int32) { + atomic.StoreInt32((*int32)(i), n) +} + +func (i *AtomicInt32) Get() int32 { + return atomic.LoadInt32((*int32)(i)) +} + +func (i *AtomicInt32) CompareAndSwap(oldval, newval int32) (swapped bool) { + return atomic.CompareAndSwapInt32((*int32)(i), oldval, newval) +} + +type AtomicUint32 uint32 + +func (i *AtomicUint32) Add(n uint32) uint32 { + return atomic.AddUint32((*uint32)(i), n) +} + +func (i *AtomicUint32) Set(n uint32) { + atomic.StoreUint32((*uint32)(i), n) +} + +func (i *AtomicUint32) Get() uint32 { + return atomic.LoadUint32((*uint32)(i)) +} + +func (i *AtomicUint32) CompareAndSwap(oldval, newval uint32) (swapped bool) { + return atomic.CompareAndSwapUint32((*uint32)(i), oldval, newval) +} + +type AtomicInt64 int64 + +func (i *AtomicInt64) Add(n int64) int64 { + return atomic.AddInt64((*int64)(i), n) +} + +func (i *AtomicInt64) Set(n int64) { + atomic.StoreInt64((*int64)(i), n) +} + +func (i *AtomicInt64) Get() int64 { + return atomic.LoadInt64((*int64)(i)) +} + +func (i *AtomicInt64) CompareAndSwap(oldval, newval int64) (swapped bool) { + return atomic.CompareAndSwapInt64((*int64)(i), oldval, newval) +} + +type AtomicUint64 uint64 + +func (i *AtomicUint64) Add(n uint64) uint64 { + return atomic.AddUint64((*uint64)(i), n) +} + +func (i *AtomicUint64) Set(n uint64) { + atomic.StoreUint64((*uint64)(i), n) +} + +func (i *AtomicUint64) Get() uint64 { + return atomic.LoadUint64((*uint64)(i)) +} + +func (i *AtomicUint64) CompareAndSwap(oldval, newval uint64) (swapped bool) { + return atomic.CompareAndSwapUint64((*uint64)(i), oldval, newval) +} + +type AtomicDuration int64 + +func (d *AtomicDuration) Add(duration time.Duration) time.Duration { + return time.Duration(atomic.AddInt64((*int64)(d), int64(duration))) +} + +func (d *AtomicDuration) Set(duration time.Duration) { + atomic.StoreInt64((*int64)(d), int64(duration)) +} + +func (d *AtomicDuration) Get() time.Duration { + return time.Duration(atomic.LoadInt64((*int64)(d))) +} + +func (d *AtomicDuration) CompareAndSwap(oldval, newval time.Duration) (swapped bool) { + return atomic.CompareAndSwapInt64((*int64)(d), int64(oldval), int64(newval)) +} + +// AtomicString gives you atomic-style APIs for string, but +// it's only a convenience wrapper that uses a mutex. So, it's +// not as efficient as the rest of the atomic types. +type AtomicString struct { + mu sync.Mutex + str string +} + +func (s *AtomicString) Set(str string) { + s.mu.Lock() + s.str = str + s.mu.Unlock() +} + +func (s *AtomicString) Get() string { + s.mu.Lock() + str := s.str + s.mu.Unlock() + return str +} + +func (s *AtomicString) CompareAndSwap(oldval, newval string) (swqpped bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.str == oldval { + s.str = newval + return true + } + return false +} + +type AtomicBool int32 + +func (b *AtomicBool) Set(v bool) { + if v { + atomic.StoreInt32((*int32)(b), 1) + } else { + atomic.StoreInt32((*int32)(b), 0) + } +} + +func (b *AtomicBool) Get() bool { + return atomic.LoadInt32((*int32)(b)) == 1 +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/semaphore.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/semaphore.go new file mode 100644 index 0000000..d310da7 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/github.com/siddontang/go/sync2/semaphore.go @@ -0,0 +1,65 @@ +package sync2 + +import ( + "sync" + "sync/atomic" + "time" +) + +func NewSemaphore(initialCount int) *Semaphore { + res := &Semaphore{ + counter: int64(initialCount), + } + res.cond.L = &res.lock + return res +} + +type Semaphore struct { + lock sync.Mutex + cond sync.Cond + counter int64 +} + +func (s *Semaphore) Release() { + s.lock.Lock() + s.counter += 1 + if s.counter >= 0 { + s.cond.Signal() + } + s.lock.Unlock() +} + +func (s *Semaphore) Acquire() { + s.lock.Lock() + for s.counter < 1 { + s.cond.Wait() + } + s.counter -= 1 + s.lock.Unlock() +} + +func (s *Semaphore) AcquireTimeout(timeout time.Duration) bool { + done := make(chan bool, 1) + // Gate used to communicate between the threads and decide what the result + // is. If the main thread decides, we have timed out, otherwise we succeed. + decided := new(int32) + go func() { + s.Acquire() + if atomic.SwapInt32(decided, 1) == 0 { + done <- true + } else { + // If we already decided the result, and this thread did not win + s.Release() + } + }() + select { + case <-done: + return true + case <-time.NewTimer(timeout).C: + if atomic.SwapInt32(decided, 1) == 1 { + // The other thread already decided the result + return true + } + return false + } +} diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/LICENSE b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/LICENSE new file mode 100644 index 0000000..6a66aea --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/PATENTS b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/PATENTS new file mode 100644 index 0000000..7330990 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/context/context.go b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/context/context.go new file mode 100644 index 0000000..77b64d0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/_vendor/vendor/golang.org/x/net/context/context.go @@ -0,0 +1,447 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context defines the Context type, which carries deadlines, +// cancelation signals, and other request-scoped values across API boundaries +// and between processes. +// +// Incoming requests to a server should create a Context, and outgoing calls to +// servers should accept a Context. The chain of function calls between must +// propagate the Context, optionally replacing it with a modified copy created +// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// +// Programs that use Contexts should follow these rules to keep interfaces +// consistent across packages and enable static analysis tools to check context +// propagation: +// +// Do not store Contexts inside a struct type; instead, pass a Context +// explicitly to each function that needs it. The Context should be the first +// parameter, typically named ctx: +// +// func DoSomething(ctx context.Context, arg Arg) error { +// // ... use ctx ... +// } +// +// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// if you are unsure about which Context to use. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The same Context may be passed to functions running in different goroutines; +// Contexts are safe for simultaneous use by multiple goroutines. +// +// See http://blog.golang.org/context for example code for a server that uses +// Contexts. +package context // import "golang.org/x/net/context" + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// A Context carries a deadline, a cancelation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context interface { + // Deadline returns the time when work done on behalf of this context + // should be canceled. Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // + // WithCancel arranges for Done to be closed when cancel is called; + // WithDeadline arranges for Done to be closed when the deadline + // expires; WithTimeout arranges for Done to be closed when the timeout + // elapses. + // + // Done is provided for use in select statements: + // + // // Stream generates values with DoSomething and sends them to out + // // until DoSomething returns an error or ctx.Done is closed. + // func Stream(ctx context.Context, out <-chan Value) error { + // for { + // v, err := DoSomething(ctx) + // if err != nil { + // return err + // } + // select { + // case <-ctx.Done(): + // return ctx.Err() + // case out <- v: + // } + // } + // } + // + // See http://blog.golang.org/pipelines for more examples of how to use + // a Done channel for cancelation. + Done() <-chan struct{} + + // Err returns a non-nil error value after Done is closed. Err returns + // Canceled if the context was canceled or DeadlineExceeded if the + // context's deadline passed. No other values for Err are defined. + // After Done is closed, successive calls to Err return the same value. + Err() error + + // Value returns the value associated with this context for key, or nil + // if no value is associated with key. Successive calls to Value with + // the same key returns the same result. + // + // Use context values only for request-scoped data that transits + // processes and API boundaries, not for passing optional parameters to + // functions. + // + // A key identifies a specific value in a Context. Functions that wish + // to store values in Context typically allocate a key in a global + // variable then use that key as the argument to context.WithValue and + // Context.Value. A key can be any type that supports equality; + // packages should define keys as an unexported type to avoid + // collisions. + // + // Packages that define a Context key should provide type-safe accessors + // for the values stores using that key: + // + // // Package user defines a User type that's stored in Contexts. + // package user + // + // import "golang.org/x/net/context" + // + // // User is the type of value stored in the Contexts. + // type User struct {...} + // + // // key is an unexported type for keys defined in this package. + // // This prevents collisions with keys defined in other packages. + // type key int + // + // // userKey is the key for user.User values in Contexts. It is + // // unexported; clients use user.NewContext and user.FromContext + // // instead of using this key directly. + // var userKey key = 0 + // + // // NewContext returns a new Context that carries value u. + // func NewContext(ctx context.Context, u *User) context.Context { + // return context.WithValue(ctx, userKey, u) + // } + // + // // FromContext returns the User value stored in ctx, if any. + // func FromContext(ctx context.Context) (*User, bool) { + // u, ok := ctx.Value(userKey).(*User) + // return u, ok + // } + Value(key interface{}) interface{} +} + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = errors.New("context canceled") + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = errors.New("context deadline exceeded") + +// An emptyCtx is never canceled, has no values, and has no deadline. It is not +// struct{}, since vars of this type must have distinct addresses. +type emptyCtx int + +func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*emptyCtx) Done() <-chan struct{} { + return nil +} + +func (*emptyCtx) Err() error { + return nil +} + +func (*emptyCtx) Value(key interface{}) interface{} { + return nil +} + +func (e *emptyCtx) String() string { + switch e { + case background: + return "context.Background" + case todo: + return "context.TODO" + } + return "unknown empty Context" +} + +var ( + background = new(emptyCtx) + todo = new(emptyCtx) +) + +// Background returns a non-nil, empty Context. It is never canceled, has no +// values, and has no deadline. It is typically used by the main function, +// initialization, and tests, and as the top-level Context for incoming +// requests. +func Background() Context { + return background +} + +// TODO returns a non-nil, empty Context. Code should use context.TODO when +// it's unclear which Context to use or it is not yet available (because the +// surrounding function has not yet been extended to accept a Context +// parameter). TODO is recognized by static analysis tools that determine +// whether Contexts are propagated correctly in a program. +func TODO() Context { + return todo +} + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc func() + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := newCancelCtx(parent) + propagateCancel(parent, &c) + return &c, func() { c.cancel(true, Canceled) } +} + +// newCancelCtx returns an initialized cancelCtx. +func newCancelCtx(parent Context) cancelCtx { + return cancelCtx{ + Context: parent, + done: make(chan struct{}), + } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent Context, child canceler) { + if parent.Done() == nil { + return // parent is never canceled + } + if p, ok := parentCancelCtx(parent); ok { + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err) + } else { + if p.children == nil { + p.children = make(map[canceler]bool) + } + p.children[child] = true + } + p.mu.Unlock() + } else { + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err()) + case <-child.Done(): + } + }() + } +} + +// parentCancelCtx follows a chain of parent references until it finds a +// *cancelCtx. This function understands how each of the concrete types in this +// package represents its parent. +func parentCancelCtx(parent Context) (*cancelCtx, bool) { + for { + switch c := parent.(type) { + case *cancelCtx: + return c, true + case *timerCtx: + return &c.cancelCtx, true + case *valueCtx: + parent = c.Context + default: + return nil, false + } + } +} + +// removeChild removes a context from its parent. +func removeChild(parent Context, child canceler) { + p, ok := parentCancelCtx(parent) + if !ok { + return + } + p.mu.Lock() + if p.children != nil { + delete(p.children, child) + } + p.mu.Unlock() +} + +// A canceler is a context type that can be canceled directly. The +// implementations are *cancelCtx and *timerCtx. +type canceler interface { + cancel(removeFromParent bool, err error) + Done() <-chan struct{} +} + +// A cancelCtx can be canceled. When canceled, it also cancels any children +// that implement canceler. +type cancelCtx struct { + Context + + done chan struct{} // closed by the first cancel call. + + mu sync.Mutex + children map[canceler]bool // set to nil by the first cancel call + err error // set to non-nil by the first cancel call +} + +func (c *cancelCtx) Done() <-chan struct{} { + return c.done +} + +func (c *cancelCtx) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *cancelCtx) String() string { + return fmt.Sprintf("%v.WithCancel", c.Context) +} + +// cancel closes c.done, cancels each of c's children, and, if +// removeFromParent is true, removes c from its parent's children. +func (c *cancelCtx) cancel(removeFromParent bool, err error) { + if err == nil { + panic("context: internal error: missing cancel error") + } + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return // already canceled + } + c.err = err + close(c.done) + for child := range c.children { + // NOTE: acquiring the child's lock while holding parent's lock. + child.cancel(false, err) + } + c.children = nil + c.mu.Unlock() + + if removeFromParent { + removeChild(c.Context, c) + } +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return WithCancel(parent) + } + c := &timerCtx{ + cancelCtx: newCancelCtx(parent), + deadline: deadline, + } + propagateCancel(parent, c) + d := deadline.Sub(time.Now()) + if d <= 0 { + c.cancel(true, DeadlineExceeded) // deadline has already passed + return c, func() { c.cancel(true, Canceled) } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.timer = time.AfterFunc(d, func() { + c.cancel(true, DeadlineExceeded) + }) + } + return c, func() { c.cancel(true, Canceled) } +} + +// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to +// implement Done and Err. It implements cancel by stopping its timer then +// delegating to cancelCtx.cancel. +type timerCtx struct { + cancelCtx + timer *time.Timer // Under cancelCtx.mu. + + deadline time.Time +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { + return c.deadline, true +} + +func (c *timerCtx) String() string { + return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) +} + +func (c *timerCtx) cancel(removeFromParent bool, err error) { + c.cancelCtx.cancel(false, err) + if removeFromParent { + // Remove this timerCtx from its parent cancelCtx's children. + removeChild(c.cancelCtx.Context, c) + } + c.mu.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.mu.Unlock() +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return &valueCtx{parent, key, val} +} + +// A valueCtx carries a key-value pair. It implements Value for that key and +// delegates all other calls to the embedded Context. +type valueCtx struct { + Context + key, val interface{} +} + +func (c *valueCtx) String() string { + return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) +} + +func (c *valueCtx) Value(key interface{}) interface{} { + if c.key == key { + return c.val + } + return c.Context.Value(key) +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/canal.go b/vendor/github.com/siddontang/go-mysql/canal/canal.go new file mode 100644 index 0000000..cdeaf98 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/canal.go @@ -0,0 +1,307 @@ +package canal + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "strconv" + "strings" + "sync" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/siddontang/go-mysql/client" + "github.com/siddontang/go-mysql/dump" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/replication" + "github.com/siddontang/go-mysql/schema" + "github.com/siddontang/go/sync2" +) + +var errCanalClosed = errors.New("canal was closed") + +// Canal can sync your MySQL data into everywhere, like Elasticsearch, Redis, etc... +// MySQL must open row format for binlog +type Canal struct { + m sync.Mutex + + cfg *Config + + master *masterInfo + dumper *dump.Dumper + dumpDoneCh chan struct{} + syncer *replication.BinlogSyncer + + rsLock sync.Mutex + rsHandlers []RowsEventHandler + + connLock sync.Mutex + conn *client.Conn + + wg sync.WaitGroup + + tableLock sync.Mutex + tables map[string]*schema.Table + + quit chan struct{} + closed sync2.AtomicBool +} + +func NewCanal(cfg *Config) (*Canal, error) { + c := new(Canal) + c.cfg = cfg + c.closed.Set(false) + c.quit = make(chan struct{}) + + os.MkdirAll(cfg.DataDir, 0755) + + c.dumpDoneCh = make(chan struct{}) + c.rsHandlers = make([]RowsEventHandler, 0, 4) + c.tables = make(map[string]*schema.Table) + + var err error + if c.master, err = loadMasterInfo(c.masterInfoPath()); err != nil { + return nil, errors.Trace(err) + } else if len(c.master.Addr) != 0 && c.master.Addr != c.cfg.Addr { + log.Infof("MySQL addr %s in old master.info, but new %s, reset", c.master.Addr, c.cfg.Addr) + // may use another MySQL, reset + c.master = &masterInfo{} + } + + c.master.Addr = c.cfg.Addr + + if err := c.prepareDumper(); err != nil { + return nil, errors.Trace(err) + } + + if err = c.prepareSyncer(); err != nil { + return nil, errors.Trace(err) + } + + if err := c.checkBinlogRowFormat(); err != nil { + return nil, errors.Trace(err) + } + + return c, nil +} + +func (c *Canal) prepareDumper() error { + var err error + dumpPath := c.cfg.Dump.ExecutionPath + if len(dumpPath) == 0 { + // ignore mysqldump, use binlog only + return nil + } + + if c.dumper, err = dump.NewDumper(dumpPath, + c.cfg.Addr, c.cfg.User, c.cfg.Password); err != nil { + return errors.Trace(err) + } + + if c.dumper == nil { + //no mysqldump, use binlog only + return nil + } + + dbs := c.cfg.Dump.Databases + tables := c.cfg.Dump.Tables + tableDB := c.cfg.Dump.TableDB + + if len(tables) == 0 { + c.dumper.AddDatabases(dbs...) + } else { + c.dumper.AddTables(tableDB, tables...) + } + + for _, ignoreTable := range c.cfg.Dump.IgnoreTables { + if seps := strings.Split(ignoreTable, ","); len(seps) == 2 { + c.dumper.AddIgnoreTables(seps[0], seps[1]) + } + } + + if c.cfg.Dump.DiscardErr { + c.dumper.SetErrOut(ioutil.Discard) + } else { + c.dumper.SetErrOut(os.Stderr) + } + + return nil +} + +func (c *Canal) Start() error { + c.wg.Add(1) + go c.run() + + return nil +} + +func (c *Canal) run() error { + defer c.wg.Done() + + if err := c.tryDump(); err != nil { + log.Errorf("canal dump mysql err: %v", err) + return errors.Trace(err) + } + + close(c.dumpDoneCh) + + if err := c.startSyncBinlog(); err != nil { + if !c.isClosed() { + log.Errorf("canal start sync binlog err: %v", err) + } + return errors.Trace(err) + } + + return nil +} + +func (c *Canal) isClosed() bool { + return c.closed.Get() +} + +func (c *Canal) Close() { + log.Infof("close canal") + + c.m.Lock() + defer c.m.Unlock() + + if c.isClosed() { + return + } + + c.closed.Set(true) + + close(c.quit) + + c.connLock.Lock() + c.conn.Close() + c.conn = nil + c.connLock.Unlock() + + if c.syncer != nil { + c.syncer.Close() + c.syncer = nil + } + + c.master.Close() + + c.wg.Wait() +} + +func (c *Canal) WaitDumpDone() <-chan struct{} { + return c.dumpDoneCh +} + +func (c *Canal) GetTable(db string, table string) (*schema.Table, error) { + key := fmt.Sprintf("%s.%s", db, table) + c.tableLock.Lock() + t, ok := c.tables[key] + c.tableLock.Unlock() + + if ok { + return t, nil + } + + t, err := schema.NewTable(c, db, table) + if err != nil { + return nil, errors.Trace(err) + } + + c.tableLock.Lock() + c.tables[key] = t + c.tableLock.Unlock() + + return t, nil +} + +// Check MySQL binlog row image, must be in FULL, MINIMAL, NOBLOB +func (c *Canal) CheckBinlogRowImage(image string) error { + // need to check MySQL binlog row image? full, minimal or noblob? + // now only log + if c.cfg.Flavor == mysql.MySQLFlavor { + if res, err := c.Execute(`SHOW GLOBAL VARIABLES LIKE "binlog_row_image"`); err != nil { + return errors.Trace(err) + } else { + // MySQL has binlog row image from 5.6, so older will return empty + rowImage, _ := res.GetString(0, 1) + if rowImage != "" && !strings.EqualFold(rowImage, image) { + return errors.Errorf("MySQL uses %s binlog row image, but we want %s", rowImage, image) + } + } + } + + return nil +} + +func (c *Canal) checkBinlogRowFormat() error { + res, err := c.Execute(`SHOW GLOBAL VARIABLES LIKE "binlog_format";`) + if err != nil { + return errors.Trace(err) + } else if f, _ := res.GetString(0, 1); f != "ROW" { + return errors.Errorf("binlog must ROW format, but %s now", f) + } + + return nil +} + +func (c *Canal) prepareSyncer() error { + seps := strings.Split(c.cfg.Addr, ":") + if len(seps) != 2 { + return errors.Errorf("invalid mysql addr format %s, must host:port", c.cfg.Addr) + } + + port, err := strconv.ParseUint(seps[1], 10, 16) + if err != nil { + return errors.Trace(err) + } + + cfg := replication.BinlogSyncerConfig{ + ServerID: c.cfg.ServerID, + Flavor: c.cfg.Flavor, + Host: seps[0], + Port: uint16(port), + User: c.cfg.User, + Password: c.cfg.Password, + } + + c.syncer = replication.NewBinlogSyncer(&cfg) + + return nil +} + +func (c *Canal) masterInfoPath() string { + return path.Join(c.cfg.DataDir, "master.info") +} + +// Execute a SQL +func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err error) { + c.connLock.Lock() + defer c.connLock.Unlock() + + retryNum := 3 + for i := 0; i < retryNum; i++ { + if c.conn == nil { + c.conn, err = client.Connect(c.cfg.Addr, c.cfg.User, c.cfg.Password, "") + if err != nil { + return nil, errors.Trace(err) + } + } + + rr, err = c.conn.Execute(cmd, args...) + if err != nil && !mysql.ErrorEqual(err, mysql.ErrBadConn) { + return + } else if mysql.ErrorEqual(err, mysql.ErrBadConn) { + c.conn.Close() + c.conn = nil + continue + } else { + return + } + } + return +} + +func (c *Canal) SyncedPosition() mysql.Position { + return c.master.Pos() +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/canal_test.go b/vendor/github.com/siddontang/go-mysql/canal/canal_test.go new file mode 100644 index 0000000..b83c79e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/canal_test.go @@ -0,0 +1,94 @@ +package canal + +import ( + "flag" + "fmt" + "os" + "testing" + + "github.com/ngaut/log" + . "github.com/pingcap/check" + "github.com/siddontang/go-mysql/mysql" +) + +var testHost = flag.String("host", "127.0.0.1", "MySQL host") + +func Test(t *testing.T) { + TestingT(t) +} + +type canalTestSuite struct { + c *Canal +} + +var _ = Suite(&canalTestSuite{}) + +func (s *canalTestSuite) SetUpSuite(c *C) { + cfg := NewDefaultConfig() + cfg.Addr = fmt.Sprintf("%s:3306", *testHost) + cfg.User = "root" + cfg.Dump.ExecutionPath = "mysqldump" + cfg.Dump.TableDB = "test" + cfg.Dump.Tables = []string{"canal_test"} + + os.RemoveAll(cfg.DataDir) + + var err error + s.c, err = NewCanal(cfg) + c.Assert(err, IsNil) + + sql := ` + CREATE TABLE IF NOT EXISTS test.canal_test ( + id int AUTO_INCREMENT, + name varchar(100), + PRIMARY KEY(id) + )ENGINE=innodb; + ` + + s.execute(c, sql) + + s.execute(c, "DELETE FROM test.canal_test") + s.execute(c, "INSERT INTO test.canal_test (name) VALUES (?), (?), (?)", "a", "b", "c") + + s.execute(c, "SET GLOBAL binlog_format = 'ROW'") + + s.c.RegRowsEventHandler(&testRowsEventHandler{}) + err = s.c.Start() + c.Assert(err, IsNil) +} + +func (s *canalTestSuite) TearDownSuite(c *C) { + if s.c != nil { + s.c.Close() + s.c = nil + } +} + +func (s *canalTestSuite) execute(c *C, query string, args ...interface{}) *mysql.Result { + r, err := s.c.Execute(query, args...) + c.Assert(err, IsNil) + return r +} + +type testRowsEventHandler struct { +} + +func (h *testRowsEventHandler) Do(e *RowsEvent) error { + log.Infof("%s %v\n", e.Action, e.Rows) + return nil +} + +func (h *testRowsEventHandler) String() string { + return "testRowsEventHandler" +} + +func (s *canalTestSuite) TestCanal(c *C) { + <-s.c.WaitDumpDone() + + for i := 1; i < 10; i++ { + s.execute(c, "INSERT INTO test.canal_test (name) VALUES (?)", fmt.Sprintf("%d", i)) + } + + err := s.c.CatchMasterPos(100) + c.Assert(err, IsNil) +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/config.go b/vendor/github.com/siddontang/go-mysql/canal/config.go new file mode 100644 index 0000000..f9bb262 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/config.go @@ -0,0 +1,79 @@ +package canal + +import ( + "io/ioutil" + "math/rand" + "time" + + "github.com/BurntSushi/toml" + "github.com/juju/errors" +) + +type DumpConfig struct { + // mysqldump execution path, like mysqldump or /usr/bin/mysqldump, etc... + // If not set, ignore using mysqldump. + ExecutionPath string `toml:"mysqldump"` + + // Will override Databases, tables is in database table_db + Tables []string `toml:"tables"` + TableDB string `toml:"table_db"` + + Databases []string `toml:"dbs"` + + // Ignore table format is db.table + IgnoreTables []string `toml:"ignore_tables"` + + // If true, discard error msg, else, output to stderr + DiscardErr bool `toml:"discard_err"` +} + +type Config struct { + Addr string `toml:"addr"` + User string `toml:"user"` + Password string `toml:"password"` + + ServerID uint32 `toml:"server_id"` + Flavor string `toml:"flavor"` + DataDir string `toml:"data_dir"` + + Dump DumpConfig `toml:"dump"` +} + +func NewConfigWithFile(name string) (*Config, error) { + data, err := ioutil.ReadFile(name) + if err != nil { + return nil, errors.Trace(err) + } + + return NewConfig(string(data)) +} + +func NewConfig(data string) (*Config, error) { + var c Config + + _, err := toml.Decode(data, &c) + if err != nil { + return nil, errors.Trace(err) + } + + return &c, nil +} + +func NewDefaultConfig() *Config { + c := new(Config) + + c.Addr = "127.0.0.1:3306" + c.User = "root" + c.Password = "" + + rand.Seed(time.Now().Unix()) + c.ServerID = uint32(rand.Intn(1000)) + 1001 + + c.Flavor = "mysql" + + c.DataDir = "./var" + c.Dump.ExecutionPath = "mysqldump" + c.Dump.DiscardErr = true + + return c +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/dump.go b/vendor/github.com/siddontang/go-mysql/canal/dump.go new file mode 100644 index 0000000..4adec2a --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/dump.go @@ -0,0 +1,120 @@ +package canal + +import ( + "strconv" + "time" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/siddontang/go-mysql/dump" + "github.com/siddontang/go-mysql/schema" +) + +type dumpParseHandler struct { + c *Canal + name string + pos uint64 +} + +func (h *dumpParseHandler) BinLog(name string, pos uint64) error { + h.name = name + h.pos = pos + return nil +} + +func (h *dumpParseHandler) Data(db string, table string, values []string) error { + if h.c.isClosed() { + return errCanalClosed + } + + tableInfo, err := h.c.GetTable(db, table) + if err != nil { + log.Errorf("get %s.%s information err: %v", db, table, err) + return errors.Trace(err) + } + + vs := make([]interface{}, len(values)) + + for i, v := range values { + if v == "NULL" { + vs[i] = nil + } else if v[0] != '\'' { + if tableInfo.Columns[i].Type == schema.TYPE_NUMBER { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + log.Errorf("parse row %v at %d error %v, skip", values, i, err) + return dump.ErrSkip + } + vs[i] = n + } else if tableInfo.Columns[i].Type == schema.TYPE_FLOAT { + f, err := strconv.ParseFloat(v, 64) + if err != nil { + log.Errorf("parse row %v at %d error %v, skip", values, i, err) + return dump.ErrSkip + } + vs[i] = f + } else { + log.Errorf("parse row %v error, invalid type at %d, skip", values, i) + return dump.ErrSkip + } + } else { + vs[i] = v[1 : len(v)-1] + } + } + + events := newRowsEvent(tableInfo, InsertAction, [][]interface{}{vs}) + return h.c.travelRowsEventHandler(events) +} + +func (c *Canal) AddDumpDatabases(dbs ...string) { + if c.dumper == nil { + return + } + + c.dumper.AddDatabases(dbs...) +} + +func (c *Canal) AddDumpTables(db string, tables ...string) { + if c.dumper == nil { + return + } + + c.dumper.AddTables(db, tables...) +} + +func (c *Canal) AddDumpIgnoreTables(db string, tables ...string) { + if c.dumper == nil { + return + } + + c.dumper.AddIgnoreTables(db, tables...) +} + +func (c *Canal) tryDump() error { + if len(c.master.Name) > 0 && c.master.Position > 0 { + // we will sync with binlog name and position + log.Infof("skip dump, use last binlog replication pos (%s, %d)", c.master.Name, c.master.Position) + return nil + } + + if c.dumper == nil { + log.Info("skip dump, no mysqldump") + return nil + } + + h := &dumpParseHandler{c: c} + + start := time.Now() + log.Info("try dump MySQL and parse") + if err := c.dumper.DumpAndParse(h); err != nil { + return errors.Trace(err) + } + + log.Infof("dump MySQL and parse OK, use %0.2f seconds, start binlog replication at (%s, %d)", + time.Now().Sub(start).Seconds(), h.name, h.pos) + + c.master.Update(h.name, uint32(h.pos)) + c.master.Save(true) + + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/handler.go b/vendor/github.com/siddontang/go-mysql/canal/handler.go new file mode 100644 index 0000000..e361181 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/handler.go @@ -0,0 +1,41 @@ +package canal + +import ( + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/siddontang/go-mysql/mysql" +) + +var ( + ErrHandleInterrupted = errors.New("do handler error, interrupted") +) + +type RowsEventHandler interface { + // Handle RowsEvent, if return ErrHandleInterrupted, canal will + // stop the sync + Do(e *RowsEvent) error + String() string +} + +func (c *Canal) RegRowsEventHandler(h RowsEventHandler) { + c.rsLock.Lock() + c.rsHandlers = append(c.rsHandlers, h) + c.rsLock.Unlock() +} + +func (c *Canal) travelRowsEventHandler(e *RowsEvent) error { + c.rsLock.Lock() + defer c.rsLock.Unlock() + + var err error + for _, h := range c.rsHandlers { + if err = h.Do(e); err != nil && !mysql.ErrorEqual(err, ErrHandleInterrupted) { + log.Errorf("handle %v err: %v", h, err) + } else if mysql.ErrorEqual(err, ErrHandleInterrupted) { + log.Errorf("handle %v err, interrupted", h) + return ErrHandleInterrupted + } + + } + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/master.go b/vendor/github.com/siddontang/go-mysql/canal/master.go new file mode 100644 index 0000000..6897d05 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/master.go @@ -0,0 +1,89 @@ +package canal + +import ( + "bytes" + "os" + "sync" + "time" + + "github.com/BurntSushi/toml" + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go/ioutil2" +) + +type masterInfo struct { + Addr string `toml:"addr"` + Name string `toml:"bin_name"` + Position uint32 `toml:"bin_pos"` + + name string + + l sync.Mutex + + lastSaveTime time.Time +} + +func loadMasterInfo(name string) (*masterInfo, error) { + var m masterInfo + + m.name = name + + f, err := os.Open(name) + if err != nil && !os.IsNotExist(errors.Cause(err)) { + return nil, errors.Trace(err) + } else if os.IsNotExist(errors.Cause(err)) { + return &m, nil + } + defer f.Close() + + _, err = toml.DecodeReader(f, &m) + + return &m, err +} + +func (m *masterInfo) Save(force bool) error { + m.l.Lock() + defer m.l.Unlock() + + n := time.Now() + if !force && n.Sub(m.lastSaveTime) < time.Second { + return nil + } + + var buf bytes.Buffer + e := toml.NewEncoder(&buf) + + e.Encode(m) + + var err error + if err = ioutil2.WriteFileAtomic(m.name, buf.Bytes(), 0644); err != nil { + log.Errorf("canal save master info to file %s err %v", m.name, err) + } + + m.lastSaveTime = n + + return errors.Trace(err) +} + +func (m *masterInfo) Update(name string, pos uint32) { + m.l.Lock() + m.Name = name + m.Position = pos + m.l.Unlock() +} + +func (m *masterInfo) Pos() mysql.Position { + var pos mysql.Position + m.l.Lock() + pos.Name = m.Name + pos.Pos = m.Position + m.l.Unlock() + + return pos +} + +func (m *masterInfo) Close() { + m.Save(true) +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/rows.go b/vendor/github.com/siddontang/go-mysql/canal/rows.go new file mode 100644 index 0000000..5c5e467 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/rows.go @@ -0,0 +1,59 @@ +package canal + +import ( + "fmt" + + "github.com/juju/errors" + "github.com/siddontang/go-mysql/schema" +) + +const ( + UpdateAction = "update" + InsertAction = "insert" + DeleteAction = "delete" +) + +type RowsEvent struct { + Table *schema.Table + Action string + // changed row list + // binlog has three update event version, v0, v1 and v2. + // for v1 and v2, the rows number must be even. + // Two rows for one event, format is [before update row, after update row] + // for update v0, only one row for a event, and we don't support this version. + Rows [][]interface{} +} + +func newRowsEvent(table *schema.Table, action string, rows [][]interface{}) *RowsEvent { + e := new(RowsEvent) + + e.Table = table + e.Action = action + e.Rows = rows + + return e +} + +// Get primary keys in one row for a table, a table may use multi fields as the PK +func GetPKValues(table *schema.Table, row []interface{}) ([]interface{}, error) { + indexes := table.PKColumns + if len(indexes) == 0 { + return nil, errors.Errorf("table %s has no PK", table) + } else if len(table.Columns) != len(row) { + return nil, errors.Errorf("table %s has %d columns, but row data %v len is %d", table, + len(table.Columns), row, len(row)) + } + + values := make([]interface{}, 0, len(indexes)) + + for _, index := range indexes { + values = append(values, row[index]) + } + + return values, nil +} + +// String implements fmt.Stringer interface. +func (r *RowsEvent) String() string { + return fmt.Sprintf("%s %s %v", r.Action, r.Table, r.Rows) +} diff --git a/vendor/github.com/siddontang/go-mysql/canal/sync.go b/vendor/github.com/siddontang/go-mysql/canal/sync.go new file mode 100644 index 0000000..e76eea4 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/canal/sync.go @@ -0,0 +1,137 @@ +package canal + +import ( + "time" + + "golang.org/x/net/context" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/replication" +) + +func (c *Canal) startSyncBinlog() error { + pos := mysql.Position{c.master.Name, c.master.Position} + + log.Infof("start sync binlog at %v", pos) + + s, err := c.syncer.StartSync(pos) + if err != nil { + return errors.Errorf("start sync replication at %v error %v", pos, err) + } + + timeout := time.Second + forceSavePos := false + for { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ev, err := s.GetEvent(ctx) + cancel() + + if err == context.DeadlineExceeded { + timeout = 2 * timeout + continue + } + + if err != nil { + return errors.Trace(err) + } + + timeout = time.Second + + //next binlog pos + pos.Pos = ev.Header.LogPos + + forceSavePos = false + + // We only save position with RotateEvent and XIDEvent. + // For RowsEvent, we can't save the position until meeting XIDEvent + // which tells the whole transaction is over. + // TODO: If we meet any DDL query, we must save too. + switch e := ev.Event.(type) { + case *replication.RotateEvent: + pos.Name = string(e.NextLogName) + pos.Pos = uint32(e.Position) + // r.ev <- pos + forceSavePos = true + log.Infof("rotate binlog to %v", pos) + case *replication.RowsEvent: + // we only focus row based event + if err = c.handleRowsEvent(ev); err != nil { + log.Errorf("handle rows event error %v", err) + return errors.Trace(err) + } + continue + case *replication.XIDEvent: + // try to save the position later + default: + continue + } + + c.master.Update(pos.Name, pos.Pos) + c.master.Save(forceSavePos) + } + + return nil +} + +func (c *Canal) handleRowsEvent(e *replication.BinlogEvent) error { + ev := e.Event.(*replication.RowsEvent) + + // Caveat: table may be altered at runtime. + schema := string(ev.Table.Schema) + table := string(ev.Table.Table) + + t, err := c.GetTable(schema, table) + if err != nil { + return errors.Trace(err) + } + var action string + switch e.Header.EventType { + case replication.WRITE_ROWS_EVENTv1, replication.WRITE_ROWS_EVENTv2: + action = InsertAction + case replication.DELETE_ROWS_EVENTv1, replication.DELETE_ROWS_EVENTv2: + action = DeleteAction + case replication.UPDATE_ROWS_EVENTv1, replication.UPDATE_ROWS_EVENTv2: + action = UpdateAction + default: + return errors.Errorf("%s not supported now", e.Header.EventType) + } + events := newRowsEvent(t, action, ev.Rows) + return c.travelRowsEventHandler(events) +} + +func (c *Canal) WaitUntilPos(pos mysql.Position, timeout int) error { + if timeout <= 0 { + timeout = 60 + } + + timer := time.NewTimer(time.Duration(timeout) * time.Second) + for { + select { + case <-timer.C: + return errors.Errorf("wait position %v err", pos) + default: + curpos := c.master.Pos() + if curpos.Compare(pos) >= 0 { + return nil + } else { + time.Sleep(100 * time.Millisecond) + } + } + } + + return nil +} + +func (c *Canal) CatchMasterPos(timeout int) error { + rr, err := c.Execute("SHOW MASTER STATUS") + if err != nil { + return errors.Trace(err) + } + + name, _ := rr.GetString(0, 0) + pos, _ := rr.GetInt(0, 1) + + return c.WaitUntilPos(mysql.Position{name, uint32(pos)}, timeout) +} diff --git a/vendor/github.com/siddontang/go-mysql/client/auth.go b/vendor/github.com/siddontang/go-mysql/client/auth.go index 217ed05..85b688c 100644 --- a/vendor/github.com/siddontang/go-mysql/client/auth.go +++ b/vendor/github.com/siddontang/go-mysql/client/auth.go @@ -2,10 +2,12 @@ package client import ( "bytes" + "crypto/tls" "encoding/binary" "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/packet" ) func (c *Conn) readInitialHandshake() error { @@ -72,6 +74,12 @@ func (c *Conn) writeAuthHandshake() error { capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG + // To enable TLS / SSL + if c.TLSConfig != nil { + capability |= CLIENT_PLUGIN_AUTH + capability |= CLIENT_SSL + } + capability &= c.capability //packet length @@ -95,6 +103,9 @@ func (c *Conn) writeAuthHandshake() error { length += len(c.db) + 1 } + // mysql_native_password + null-terminated + length += 21 + 1 + c.capability = capability data := make([]byte, length+4) @@ -115,6 +126,25 @@ func (c *Conn) writeAuthHandshake() error { //use default collation id 33 here, is utf-8 data[12] = byte(DEFAULT_COLLATION_ID) + // SSL Connection Request Packet + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + if c.TLSConfig != nil { + // Send TLS / SSL request packet + if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil { + return err + } + + // Switch to TLS + tlsConn := tls.Client(c.Conn.Conn, c.TLSConfig) + if err := tlsConn.Handshake(); err != nil { + return err + } + + currentSequence := c.Sequence + c.Conn = packet.NewConn(tlsConn) + c.Sequence = currentSequence + } + //Filler [23 bytes] (all 0x00) pos := 13 + 23 @@ -122,7 +152,7 @@ func (c *Conn) writeAuthHandshake() error { if len(c.user) > 0 { pos += copy(data[pos:], c.user) } - //data[pos] = 0x00 + data[pos] = 0x00 pos++ // auth [length encoded integer] @@ -132,8 +162,13 @@ func (c *Conn) writeAuthHandshake() error { // db [null terminated string] if len(c.db) > 0 { pos += copy(data[pos:], c.db) - //data[pos] = 0x00 + data[pos] = 0x00 + pos++ } + // Assume native client during response + pos += copy(data[pos:], "mysql_native_password") + data[pos] = 0x00 + return c.WritePacket(data) } diff --git a/vendor/github.com/siddontang/go-mysql/client/client_test.go b/vendor/github.com/siddontang/go-mysql/client/client_test.go index dbf5a70..85dd8f2 100644 --- a/vendor/github.com/siddontang/go-mysql/client/client_test.go +++ b/vendor/github.com/siddontang/go-mysql/client/client_test.go @@ -1,16 +1,19 @@ package client import ( + "crypto/tls" "flag" "fmt" + "strings" "testing" - . "gopkg.in/check.v1" + . "github.com/pingcap/check" "github.com/siddontang/go-mysql/mysql" ) -var testAddr = flag.String("addr", "127.0.0.1:3306", "MySQL server address") +var testHost = flag.String("host", "127.0.0.1", "MySQL server host") +var testPort = flag.Int("port", 3306, "MySQL server port") var testUser = flag.String("user", "root", "MySQL user") var testPassword = flag.String("pass", "", "MySQL password") var testDB = flag.String("db", "test", "MySQL test database") @@ -27,7 +30,8 @@ var _ = Suite(&clientTestSuite{}) func (s *clientTestSuite) SetUpSuite(c *C) { var err error - s.c, err = Connect(*testAddr, *testUser, *testPassword, *testDB) + addr := fmt.Sprintf("%s:%d", *testHost, *testPort) + s.c, err = Connect(addr, *testUser, *testPassword, *testDB) if err != nil { c.Fatal(err) } @@ -74,6 +78,23 @@ func (s *clientTestSuite) TestConn_Ping(c *C) { c.Assert(err, IsNil) } +func (s *clientTestSuite) TestConn_TLS(c *C) { + // Verify that the provided tls.Config is used when attempting to connect to mysql. + // An empty tls.Config will result in a connection error. + addr := fmt.Sprintf("%s:%d", *testHost, *testPort) + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + c.TLSConfig = &tls.Config{} + }) + if err == nil { + c.Fatal("expected error") + } + + expected := "either ServerName or InsecureSkipVerify must be specified in the tls.Config" + if !strings.Contains(err.Error(), expected) { + c.Fatal("expected '%s' to contain '%s'", err.Error(), expected) + } +} + func (s *clientTestSuite) TestConn_Insert(c *C) { str := `insert into mixer_test_conn (id, str, f, e) values(1, "a", 3.14, "test1")` diff --git a/vendor/github.com/siddontang/go-mysql/client/conn.go b/vendor/github.com/siddontang/go-mysql/client/conn.go index da37648..54ee3f0 100644 --- a/vendor/github.com/siddontang/go-mysql/client/conn.go +++ b/vendor/github.com/siddontang/go-mysql/client/conn.go @@ -1,6 +1,7 @@ package client import ( + "crypto/tls" "fmt" "net" "strings" @@ -14,9 +15,10 @@ import ( type Conn struct { *packet.Conn - user string - password string - db string + user string + password string + db string + TLSConfig *tls.Config capability uint32 @@ -38,7 +40,8 @@ func getNetProto(addr string) string { } // Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock. -func Connect(addr string, user string, password string, dbName string) (*Conn, error) { +// Accepts a series of configuration functions as a variadic argument. +func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { proto := getNetProto(addr) c := new(Conn) @@ -57,6 +60,11 @@ func Connect(addr string, user string, password string, dbName string) (*Conn, e //use default charset here, utf-8 c.charset = DEFAULT_CHARSET + // Apply configuration functions. + for i := range options { + options[i](c) + } + if err = c.handshake(); err != nil { return nil, errors.Trace(err) } diff --git a/vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go b/vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go new file mode 100644 index 0000000..3990d00 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/cmd/go-binlogparser/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "flag" + "os" + + "github.com/siddontang/go-mysql/replication" +) + +var name = flag.String("name", "", "binlog file name") +var offset = flag.Int64("offset", 0, "parse start offset") + +func main() { + flag.Parse() + + p := replication.NewBinlogParser() + + f := func(e *replication.BinlogEvent) error { + e.Dump(os.Stdout) + return nil + } + + err := p.ParseFile(*name, *offset, f) + + if err != nil { + println(err) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go b/vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go new file mode 100644 index 0000000..4d0c1df --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/cmd/go-canal/main.go @@ -0,0 +1,101 @@ +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/siddontang/go-mysql/canal" +) + +var host = flag.String("host", "127.0.0.1", "MySQL host") +var port = flag.Int("port", 3306, "MySQL port") +var user = flag.String("user", "root", "MySQL user, must have replication privilege") +var password = flag.String("password", "", "MySQL password") + +var flavor = flag.String("flavor", "mysql", "Flavor: mysql or mariadb") + +var dataDir = flag.String("data-dir", "./var", "Path to store data, like master.info") + +var serverID = flag.Int("server-id", 101, "Unique Server ID") +var mysqldump = flag.String("mysqldump", "mysqldump", "mysqldump execution path") + +var dbs = flag.String("dbs", "test", "dump databases, seperated by comma") +var tables = flag.String("tables", "", "dump tables, seperated by comma, will overwrite dbs") +var tableDB = flag.String("table_db", "test", "database for dump tables") +var ignoreTables = flag.String("ignore_tables", "", "ignore tables, must be database.table format, separated by comma") + +func main() { + flag.Parse() + + cfg := canal.NewDefaultConfig() + cfg.Addr = fmt.Sprintf("%s:%d", *host, *port) + cfg.User = *user + cfg.Password = *password + cfg.Flavor = *flavor + cfg.DataDir = *dataDir + + cfg.ServerID = uint32(*serverID) + cfg.Dump.ExecutionPath = *mysqldump + cfg.Dump.DiscardErr = false + + c, err := canal.NewCanal(cfg) + if err != nil { + fmt.Printf("create canal err %v", err) + os.Exit(1) + } + + if len(*ignoreTables) == 0 { + subs := strings.Split(*ignoreTables, ",") + for _, sub := range subs { + if seps := strings.Split(sub, "."); len(seps) == 2 { + c.AddDumpIgnoreTables(seps[0], seps[1]) + } + } + } + + if len(*tables) > 0 && len(*tableDB) > 0 { + subs := strings.Split(*tables, ",") + c.AddDumpTables(*tableDB, subs...) + } else if len(*dbs) > 0 { + subs := strings.Split(*dbs, ",") + c.AddDumpDatabases(subs...) + } + + c.RegRowsEventHandler(&handler{}) + + err = c.Start() + if err != nil { + fmt.Printf("start canal err %V", err) + os.Exit(1) + } + + sc := make(chan os.Signal, 1) + signal.Notify(sc, + os.Kill, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT) + + <-sc + + c.Close() +} + +type handler struct { +} + +func (h *handler) Do(e *canal.RowsEvent) error { + fmt.Printf("%v\n", e) + + return nil +} + +func (h *handler) String() string { + return "TestHandler" +} diff --git a/vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go b/vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go new file mode 100644 index 0000000..5521c97 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/cmd/go-mysqlbinlog/main.go @@ -0,0 +1,76 @@ +// go-mysqlbinlog: a simple binlog tool to sync remote MySQL binlog. +// go-mysqlbinlog supports semi-sync mode like facebook mysqlbinlog. +// see http://yoshinorimatsunobu.blogspot.com/2014/04/semi-synchronous-replication-at-facebook.html +package main + +import ( + "flag" + "fmt" + "os" + + "golang.org/x/net/context" + + "github.com/juju/errors" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/replication" +) + +var host = flag.String("host", "127.0.0.1", "MySQL host") +var port = flag.Int("port", 3306, "MySQL port") +var user = flag.String("user", "root", "MySQL user, must have replication privilege") +var password = flag.String("password", "", "MySQL password") + +var flavor = flag.String("flavor", "mysql", "Flavor: mysql or mariadb") + +var file = flag.String("file", "", "Binlog filename") +var pos = flag.Int("pos", 4, "Binlog position") + +var semiSync = flag.Bool("semisync", false, "Support semi sync") +var backupPath = flag.String("backup_path", "", "backup path to store binlog files") + +var rawMode = flag.Bool("raw", false, "Use raw mode") + +func main() { + flag.Parse() + + cfg := replication.BinlogSyncerConfig{ + ServerID: 101, + Flavor: *flavor, + + Host: *host, + Port: uint16(*port), + User: *user, + Password: *password, + RawModeEanbled: *rawMode, + SemiSyncEnabled: *semiSync, + } + + b := replication.NewBinlogSyncer(&cfg) + + pos := mysql.Position{*file, uint32(*pos)} + if len(*backupPath) > 0 { + // Backup will always use RawMode. + err := b.StartBackup(*backupPath, pos, 0) + if err != nil { + fmt.Printf("Start backup error: %v\n", errors.ErrorStack(err)) + return + } + } else { + s, err := b.StartSync(pos) + if err != nil { + fmt.Printf("Start sync error: %v\n", errors.ErrorStack(err)) + return + } + + for { + e, err := s.GetEvent(context.Background()) + if err != nil { + fmt.Printf("Get event error: %v\n", errors.ErrorStack(err)) + return + } + + e.Dump(os.Stdout) + } + } + +} diff --git a/vendor/github.com/siddontang/go-mysql/cmd/go-mysqldump/main.go b/vendor/github.com/siddontang/go-mysql/cmd/go-mysqldump/main.go new file mode 100644 index 0000000..80a145a --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/cmd/go-mysqldump/main.go @@ -0,0 +1,66 @@ +package main + +import ( + "flag" + "fmt" + "os" + "strings" + + "github.com/juju/errors" + "github.com/siddontang/go-mysql/dump" +) + +var addr = flag.String("addr", "127.0.0.1:3306", "MySQL addr") +var user = flag.String("user", "root", "MySQL user") +var password = flag.String("password", "", "MySQL password") +var execution = flag.String("exec", "mysqldump", "mysqldump execution path") +var output = flag.String("o", "", "dump output, empty for stdout") + +var dbs = flag.String("dbs", "", "dump databases, seperated by comma") +var tables = flag.String("tables", "", "dump tables, seperated by comma, will overwrite dbs") +var tableDB = flag.String("table_db", "", "database for dump tables") +var ignoreTables = flag.String("ignore_tables", "", "ignore tables, must be database.table format, separated by comma") + +func main() { + flag.Parse() + + d, err := dump.NewDumper(*execution, *addr, *user, *password) + if err != nil { + fmt.Printf("Create Dumper error %v\n", errors.ErrorStack(err)) + os.Exit(1) + } + + if len(*ignoreTables) == 0 { + subs := strings.Split(*ignoreTables, ",") + for _, sub := range subs { + if seps := strings.Split(sub, "."); len(seps) == 2 { + d.AddIgnoreTables(seps[0], seps[1]) + } + } + } + + if len(*tables) > 0 && len(*tableDB) > 0 { + subs := strings.Split(*tables, ",") + d.AddTables(*tableDB, subs...) + } else if len(*dbs) > 0 { + subs := strings.Split(*dbs, ",") + d.AddDatabases(subs...) + } + + var f = os.Stdout + + if len(*output) > 0 { + f, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + fmt.Printf("Open file error %v\n", errors.ErrorStack(err)) + os.Exit(1) + } + } + + defer f.Close() + + if err = d.Dump(f); err != nil { + fmt.Printf("Dump MySQL error %v\n", errors.ErrorStack(err)) + os.Exit(1) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/docker/Makefile b/vendor/github.com/siddontang/go-mysql/docker/Makefile new file mode 100644 index 0000000..afce000 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/docker/Makefile @@ -0,0 +1,53 @@ +MYSQL_IMAGE=siddontang/mysql:latest +MARIADB_IMAGE=siddontang/mariadb:latest + +all: run_mysql1 run_mysql2 run_mysql3 run_mariadb1 run_mariadb2 run_mariadb3 + +run_mysql1: + @docker run -d -p 3306:3306 --name=mysql1 -e "GTID_MODE=on" -e "SERVER_ID=1" ${MYSQL_IMAGE} + +run_mysql2: + @docker run -d -p 3307:3306 --name=mysql2 -e "GTID_MODE=on" -e "SERVER_ID=2" ${MYSQL_IMAGE} + +run_mysql3: + @docker run -d -p 3308:3306 --name=mysql3 -e "GTID_MODE=on" -e "SERVER_ID=3" ${MYSQL_IMAGE} + +run_mariadb1: + @docker run -d -p 3316:3306 --name=mariadb1 -e "SERVER_ID=4" ${MARIADB_IMAGE} + +run_mariadb2: + @docker run -d -p 3317:3306 --name=mariadb2 -e "SERVER_ID=5" ${MARIADB_IMAGE} + +run_mariadb3: + @docker run -d -p 3318:3306 --name=mariadb3 -e "SERVER_ID=6" ${MARIADB_IMAGE} + +image_mysql: + @docker pull ${MYSQL_IMAGE} + +image_maridb: + @docker pull ${MARIADB_IMAGE} + +image: image_mysql image_maridb + +stop_mysql1: + @docker stop mysql1 + +stop_mysql2: + @docker stop mysql2 + +stop_mysql3: + @docker stop mysql3 + +stop_mariadb1: + @docker stop mariadb1 + +stop_mariadb2: + @docker stop mariadb2 + +stop_mariadb3: + @docker stop mariadb3 + +stop: stop_mysql1 stop_mysql2 stop_mysql3 stop_mariadb1 stop_mariadb2 stop_mariadb3 + +clean: + @docker rm -f mysql1 mysql2 mysql3 mariadb1 mariadb2 mariadb3 diff --git a/vendor/github.com/siddontang/go-mysql/driver/dirver_test.go b/vendor/github.com/siddontang/go-mysql/driver/dirver_test.go new file mode 100644 index 0000000..54a72cb --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/driver/dirver_test.go @@ -0,0 +1,73 @@ +package driver + +import ( + "flag" + "fmt" + "testing" + + "github.com/jmoiron/sqlx" + . "github.com/pingcap/check" +) + +// Use docker mysql to test, mysql is 3306 +var testHost = flag.String("host", "127.0.0.1", "MySQL master host") + +func TestDriver(t *testing.T) { + TestingT(t) +} + +type testDriverSuite struct { + db *sqlx.DB +} + +var _ = Suite(&testDriverSuite{}) + +func (s *testDriverSuite) SetUpSuite(c *C) { + dsn := fmt.Sprintf("root@%s:3306?test", *testHost) + + var err error + s.db, err = sqlx.Open("mysql", dsn) + c.Assert(err, IsNil) +} + +func (s *testDriverSuite) TearDownSuite(c *C) { + if s.db != nil { + s.db.Close() + } +} + +func (s *testDriverSuite) TestConn(c *C) { + var n int + err := s.db.Get(&n, "SELECT 1") + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) + + _, err = s.db.Exec("USE test") + c.Assert(err, IsNil) +} + +func (s *testDriverSuite) TestStmt(c *C) { + stmt, err := s.db.Preparex("SELECT ? + ?") + c.Assert(err, IsNil) + + var n int + err = stmt.Get(&n, 1, 1) + c.Assert(err, IsNil) + c.Assert(n, Equals, 2) + + err = stmt.Close() + c.Assert(err, IsNil) +} + +func (s *testDriverSuite) TestTransaction(c *C) { + tx, err := s.db.Beginx() + c.Assert(err, IsNil) + + var n int + err = tx.Get(&n, "SELECT 1") + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) + + err = tx.Commit() + c.Assert(err, IsNil) +} diff --git a/vendor/github.com/siddontang/go-mysql/driver/driver.go b/vendor/github.com/siddontang/go-mysql/driver/driver.go new file mode 100644 index 0000000..6263929 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/driver/driver.go @@ -0,0 +1,230 @@ +// This package implements database/sql/driver interface, +// so we can use go-mysql with database/sql +package driver + +import ( + "database/sql" + sqldriver "database/sql/driver" + "fmt" + "io" + "strings" + + "github.com/juju/errors" + "github.com/siddontang/go-mysql/client" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go/hack" +) + +type driver struct { +} + +// DSN user:password@addr[?db] +func (d driver) Open(dsn string) (sqldriver.Conn, error) { + seps := strings.Split(dsn, "@") + if len(seps) != 2 { + return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]") + } + + var user string + var password string + var addr string + var db string + + if ss := strings.Split(seps[0], ":"); len(ss) == 2 { + user, password = ss[0], ss[1] + } else if len(ss) == 1 { + user = ss[0] + } else { + return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]") + } + + if ss := strings.Split(seps[1], "?"); len(ss) == 2 { + addr, db = ss[0], ss[1] + } else if len(ss) == 1 { + addr = ss[0] + } else { + return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]") + } + + c, err := client.Connect(addr, user, password, db) + if err != nil { + return nil, err + } + + return &conn{c}, nil +} + +type conn struct { + *client.Conn +} + +func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { + st, err := c.Conn.Prepare(query) + if err != nil { + return nil, errors.Trace(err) + } + + return &stmt{st}, nil +} + +func (c *conn) Close() error { + return c.Conn.Close() +} + +func (c *conn) Begin() (sqldriver.Tx, error) { + err := c.Conn.Begin() + if err != nil { + return nil, errors.Trace(err) + } + + return &tx{c.Conn}, nil +} + +func buildArgs(args []sqldriver.Value) []interface{} { + a := make([]interface{}, len(args)) + + for i, arg := range args { + a[i] = arg + } + + return a +} + +func replyError(err error) error { + if mysql.ErrorEqual(err, mysql.ErrBadConn) { + return sqldriver.ErrBadConn + } else { + return errors.Trace(err) + } +} + +func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, error) { + a := buildArgs(args) + r, err := c.Conn.Execute(query, a...) + if err != nil { + return nil, replyError(err) + } + return &result{r}, nil +} + +func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, error) { + a := buildArgs(args) + r, err := c.Conn.Execute(query, a...) + if err != nil { + return nil, replyError(err) + } + return newRows(r.Resultset) +} + +type stmt struct { + *client.Stmt +} + +func (s *stmt) Close() error { + return s.Stmt.Close() +} + +func (s *stmt) NumInput() int { + return s.Stmt.ParamNum() +} + +func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) { + a := buildArgs(args) + r, err := s.Stmt.Execute(a...) + if err != nil { + return nil, replyError(err) + } + return &result{r}, nil +} + +func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) { + a := buildArgs(args) + r, err := s.Stmt.Execute(a...) + if err != nil { + return nil, replyError(err) + } + return newRows(r.Resultset) +} + +type tx struct { + *client.Conn +} + +func (t *tx) Commit() error { + return t.Conn.Commit() +} + +func (t *tx) Rollback() error { + return t.Conn.Rollback() +} + +type result struct { + *mysql.Result +} + +func (r *result) LastInsertId() (int64, error) { + return int64(r.Result.InsertId), nil +} + +func (r *result) RowsAffected() (int64, error) { + return int64(r.Result.AffectedRows), nil +} + +type rows struct { + *mysql.Resultset + + columns []string + step int +} + +func newRows(r *mysql.Resultset) (*rows, error) { + if r == nil { + return nil, fmt.Errorf("invalid mysql query, no correct result") + } + + rs := new(rows) + rs.Resultset = r + + rs.columns = make([]string, len(r.Fields)) + + for i, f := range r.Fields { + rs.columns[i] = hack.String(f.Name) + } + rs.step = 0 + + return rs, nil +} + +func (r *rows) Columns() []string { + return r.columns +} + +func (r *rows) Close() error { + r.step = -1 + return nil +} + +func (r *rows) Next(dest []sqldriver.Value) error { + if r.step >= r.Resultset.RowNumber() { + return io.EOF + } else if r.step == -1 { + return io.ErrUnexpectedEOF + } + + for i := 0; i < r.Resultset.ColumnNumber(); i++ { + value, err := r.Resultset.GetValue(r.step, i) + if err != nil { + return err + } + + dest[i] = sqldriver.Value(value) + } + + r.step++ + + return nil +} + +func init() { + sql.Register("mysql", driver{}) +} diff --git a/vendor/github.com/siddontang/go-mysql/dump/dump.go b/vendor/github.com/siddontang/go-mysql/dump/dump.go new file mode 100644 index 0000000..a6ff209 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/dump/dump.go @@ -0,0 +1,161 @@ +package dump + +import ( + "fmt" + "io" + "os" + "os/exec" + "strings" + + "github.com/juju/errors" +) + +// Unlick mysqldump, Dumper is designed for parsing and syning data easily. +type Dumper struct { + // mysqldump execution path, like mysqldump or /usr/bin/mysqldump, etc... + ExecutionPath string + + Addr string + User string + Password string + + // Will override Databases + Tables []string + TableDB string + + Databases []string + + IgnoreTables map[string][]string + + ErrOut io.Writer +} + +func NewDumper(executionPath string, addr string, user string, password string) (*Dumper, error) { + if len(executionPath) == 0 { + return nil, nil + } + + path, err := exec.LookPath(executionPath) + if err != nil { + return nil, errors.Trace(err) + } + + d := new(Dumper) + d.ExecutionPath = path + d.Addr = addr + d.User = user + d.Password = password + d.Tables = make([]string, 0, 16) + d.Databases = make([]string, 0, 16) + d.IgnoreTables = make(map[string][]string) + + d.ErrOut = os.Stderr + + return d, nil +} + +func (d *Dumper) SetErrOut(o io.Writer) { + d.ErrOut = o +} + +func (d *Dumper) AddDatabases(dbs ...string) { + d.Databases = append(d.Databases, dbs...) +} + +func (d *Dumper) AddTables(db string, tables ...string) { + if d.TableDB != db { + d.TableDB = db + d.Tables = d.Tables[0:0] + } + + d.Tables = append(d.Tables, tables...) +} + +func (d *Dumper) AddIgnoreTables(db string, tables ...string) { + t, _ := d.IgnoreTables[db] + t = append(t, tables...) + d.IgnoreTables[db] = t +} + +func (d *Dumper) Reset() { + d.Tables = d.Tables[0:0] + d.TableDB = "" + d.IgnoreTables = make(map[string][]string) + d.Databases = d.Databases[0:0] +} + +func (d *Dumper) Dump(w io.Writer) error { + args := make([]string, 0, 16) + + // Common args + seps := strings.Split(d.Addr, ":") + args = append(args, fmt.Sprintf("--host=%s", seps[0])) + if len(seps) > 1 { + args = append(args, fmt.Sprintf("--port=%s", seps[1])) + } + + args = append(args, fmt.Sprintf("--user=%s", d.User)) + args = append(args, fmt.Sprintf("--password=%s", d.Password)) + + args = append(args, "--master-data") + args = append(args, "--single-transaction") + args = append(args, "--skip-lock-tables") + + // Disable uncessary data + args = append(args, "--compact") + args = append(args, "--skip-opt") + args = append(args, "--quick") + + // We only care about data + args = append(args, "--no-create-info") + + // Multi row is easy for us to parse the data + args = append(args, "--skip-extended-insert") + + for db, tables := range d.IgnoreTables { + for _, table := range tables { + args = append(args, fmt.Sprintf("--ignore-table=%s.%s", db, table)) + } + } + + if len(d.Tables) == 0 && len(d.Databases) == 0 { + args = append(args, "--all-databases") + } else if len(d.Tables) == 0 { + args = append(args, "--databases") + args = append(args, d.Databases...) + } else { + args = append(args, d.TableDB) + args = append(args, d.Tables...) + + // If we only dump some tables, the dump data will not have database name + // which makes us hard to parse, so here we add it manually. + + w.Write([]byte(fmt.Sprintf("USE `%s`;\n", d.TableDB))) + } + + cmd := exec.Command(d.ExecutionPath, args...) + + cmd.Stderr = d.ErrOut + cmd.Stdout = w + + return cmd.Run() +} + +// Dump MySQL and parse immediately +func (d *Dumper) DumpAndParse(h ParseHandler) error { + r, w := io.Pipe() + + done := make(chan error, 1) + go func() { + err := Parse(r, h) + r.CloseWithError(err) + done <- err + }() + + err := d.Dump(w) + w.CloseWithError(err) + + err = <-done + + return errors.Trace(err) +} diff --git a/vendor/github.com/siddontang/go-mysql/dump/dump_test.go b/vendor/github.com/siddontang/go-mysql/dump/dump_test.go new file mode 100644 index 0000000..39e430f --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/dump/dump_test.go @@ -0,0 +1,198 @@ +package dump + +import ( + "bytes" + "flag" + "fmt" + "io/ioutil" + "os" + "testing" + + . "github.com/pingcap/check" + "github.com/siddontang/go-mysql/client" +) + +// use docker mysql for test +var host = flag.String("host", "127.0.0.1", "MySQL host") +var port = flag.Int("port", 3306, "MySQL host") + +var execution = flag.String("exec", "mysqldump", "mysqldump execution path") + +func Test(t *testing.T) { + TestingT(t) +} + +type schemaTestSuite struct { + conn *client.Conn + d *Dumper +} + +var _ = Suite(&schemaTestSuite{}) + +func (s *schemaTestSuite) SetUpSuite(c *C) { + var err error + s.conn, err = client.Connect(fmt.Sprintf("%s:%d", *host, *port), "root", "", "") + c.Assert(err, IsNil) + + s.d, err = NewDumper(*execution, fmt.Sprintf("%s:%d", *host, *port), "root", "") + c.Assert(err, IsNil) + c.Assert(s.d, NotNil) + + s.d.SetErrOut(os.Stderr) + + _, err = s.conn.Execute("CREATE DATABASE IF NOT EXISTS test1") + c.Assert(err, IsNil) + + _, err = s.conn.Execute("CREATE DATABASE IF NOT EXISTS test2") + c.Assert(err, IsNil) + + str := `CREATE TABLE IF NOT EXISTS test%d.t%d ( + id int AUTO_INCREMENT, + name varchar(256), + PRIMARY KEY(id) + ) ENGINE=INNODB` + _, err = s.conn.Execute(fmt.Sprintf(str, 1, 1)) + c.Assert(err, IsNil) + + _, err = s.conn.Execute(fmt.Sprintf(str, 2, 1)) + c.Assert(err, IsNil) + + _, err = s.conn.Execute(fmt.Sprintf(str, 1, 2)) + c.Assert(err, IsNil) + + _, err = s.conn.Execute(fmt.Sprintf(str, 2, 2)) + c.Assert(err, IsNil) + + str = `INSERT INTO test%d.t%d (name) VALUES ("a"), ("b"), ("\\"), ("''")` + + _, err = s.conn.Execute(fmt.Sprintf(str, 1, 1)) + c.Assert(err, IsNil) + + _, err = s.conn.Execute(fmt.Sprintf(str, 2, 1)) + c.Assert(err, IsNil) + + _, err = s.conn.Execute(fmt.Sprintf(str, 1, 2)) + c.Assert(err, IsNil) + + _, err = s.conn.Execute(fmt.Sprintf(str, 2, 2)) + c.Assert(err, IsNil) +} + +func (s *schemaTestSuite) TearDownSuite(c *C) { + if s.conn != nil { + _, err := s.conn.Execute("DROP DATABASE IF EXISTS test1") + c.Assert(err, IsNil) + + _, err = s.conn.Execute("DROP DATABASE IF EXISTS test2") + c.Assert(err, IsNil) + + s.conn.Close() + } +} + +func (s *schemaTestSuite) TestDump(c *C) { + // Using mysql 5.7 can't work, error: + // mysqldump: Error 1412: Table definition has changed, + // please retry transaction when dumping table `test_replication` at row: 0 + // err := s.d.Dump(ioutil.Discard) + // c.Assert(err, IsNil) + + s.d.AddDatabases("test1", "test2") + + s.d.AddIgnoreTables("test1", "t2") + + err := s.d.Dump(ioutil.Discard) + c.Assert(err, IsNil) + + s.d.AddTables("test1", "t1") + + err = s.d.Dump(ioutil.Discard) + c.Assert(err, IsNil) +} + +type testParseHandler struct { +} + +func (h *testParseHandler) BinLog(name string, pos uint64) error { + return nil +} + +func (h *testParseHandler) Data(schema string, table string, values []string) error { + return nil +} + +func (s *parserTestSuite) TestParseFindTable(c *C) { + tbl := []struct { + sql string + table string + }{ + {"INSERT INTO `note` VALUES ('title', 'here is sql: INSERT INTO `table` VALUES (\\'some value\\')');", "note"}, + {"INSERT INTO `note` VALUES ('1', '2', '3');", "note"}, + {"INSERT INTO `a.b` VALUES ('1');", "a.b"}, + } + + for _, t := range tbl { + res := valuesExp.FindAllStringSubmatch(t.sql, -1)[0][1] + c.Assert(res, Equals, t.table) + } +} + +type parserTestSuite struct { +} + +var _ = Suite(&parserTestSuite{}) + +func (s *parserTestSuite) TestUnescape(c *C) { + tbl := []struct { + escaped string + expected string + }{ + {`\\n`, `\n`}, + {`\\t`, `\t`}, + {`\\"`, `\"`}, + {`\\'`, `\'`}, + {`\\0`, `\0`}, + {`\\b`, `\b`}, + {`\\Z`, `\Z`}, + {`\\r`, `\r`}, + {`abc`, `abc`}, + {`abc\`, `abc`}, + {`ab\c`, `abc`}, + {`\abc`, `abc`}, + } + + for _, t := range tbl { + unesacped := unescapeString(t.escaped) + c.Assert(unesacped, Equals, t.expected) + } +} + +func (s *schemaTestSuite) TestParse(c *C) { + var buf bytes.Buffer + + s.d.Reset() + + s.d.AddDatabases("test1", "test2") + + err := s.d.Dump(&buf) + c.Assert(err, IsNil) + + err = Parse(&buf, new(testParseHandler)) + c.Assert(err, IsNil) +} + +func (s *parserTestSuite) TestParseValue(c *C) { + str := `'abc\\',''` + values, err := parseValues(str) + c.Assert(err, IsNil) + c.Assert(values, DeepEquals, []string{`'abc\'`, `''`}) + + str = `123,'\Z#÷QÎx£. Æ‘ÇoPâÅ_\r—\\','','qn'` + values, err = parseValues(str) + c.Assert(err, IsNil) + c.Assert(values, HasLen, 4) + + str = `123,'\Z#÷QÎx£. Æ‘ÇoPâÅ_\r—\\','','qn\'` + values, err = parseValues(str) + c.Assert(err, NotNil) +} diff --git a/vendor/github.com/siddontang/go-mysql/dump/parser.go b/vendor/github.com/siddontang/go-mysql/dump/parser.go new file mode 100644 index 0000000..69f65c7 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/dump/parser.go @@ -0,0 +1,188 @@ +package dump + +import ( + "bufio" + "fmt" + "io" + "regexp" + "strconv" + + "github.com/juju/errors" + "github.com/siddontang/go-mysql/mysql" +) + +var ( + ErrSkip = errors.New("Handler error, but skipped") +) + +type ParseHandler interface { + // Parse CHANGE MASTER TO MASTER_LOG_FILE=name, MASTER_LOG_POS=pos; + BinLog(name string, pos uint64) error + + Data(schema string, table string, values []string) error +} + +var binlogExp *regexp.Regexp +var useExp *regexp.Regexp +var valuesExp *regexp.Regexp + +func init() { + binlogExp = regexp.MustCompile("^CHANGE MASTER TO MASTER_LOG_FILE='(.+)', MASTER_LOG_POS=(\\d+);") + useExp = regexp.MustCompile("^USE `(.+)`;") + valuesExp = regexp.MustCompile("^INSERT INTO `(.+?)` VALUES \\((.+)\\);$") +} + +// Parse the dump data with Dumper generate. +// It can not parse all the data formats with mysqldump outputs +func Parse(r io.Reader, h ParseHandler) error { + rb := bufio.NewReaderSize(r, 1024*16) + + var db string + var binlogParsed bool + + for { + line, err := rb.ReadString('\n') + if err != nil && err != io.EOF { + return errors.Trace(err) + } else if mysql.ErrorEqual(err, io.EOF) { + break + } + + line = line[0 : len(line)-1] + + if !binlogParsed { + if m := binlogExp.FindAllStringSubmatch(line, -1); len(m) == 1 { + name := m[0][1] + pos, err := strconv.ParseUint(m[0][2], 10, 64) + if err != nil { + return errors.Errorf("parse binlog %v err, invalid number", line) + } + + if err = h.BinLog(name, pos); err != nil && err != ErrSkip { + return errors.Trace(err) + } + + binlogParsed = true + } + } + + if m := useExp.FindAllStringSubmatch(line, -1); len(m) == 1 { + db = m[0][1] + } + + if m := valuesExp.FindAllStringSubmatch(line, -1); len(m) == 1 { + table := m[0][1] + + values, err := parseValues(m[0][2]) + if err != nil { + return errors.Errorf("parse values %v err", line) + } + + if err = h.Data(db, table, values); err != nil && err != ErrSkip { + return errors.Trace(err) + } + } + } + + return nil +} + +func parseValues(str string) ([]string, error) { + // values are seperated by comma, but we can not split using comma directly + // string is enclosed by single quote + + // a simple implementation, may be more robust later. + + values := make([]string, 0, 8) + + i := 0 + for i < len(str) { + if str[i] != '\'' { + // no string, read until comma + j := i + 1 + for ; j < len(str) && str[j] != ','; j++ { + } + values = append(values, str[i:j]) + // skip , + i = j + 1 + } else { + // read string until another single quote + j := i + 1 + + escaped := false + for j < len(str) { + if str[j] == '\\' { + // skip escaped character + j += 2 + escaped = true + continue + } else if str[j] == '\'' { + break + } else { + j++ + } + } + + if j >= len(str) { + return nil, fmt.Errorf("parse quote values error") + } + + value := str[i : j+1] + if escaped { + value = unescapeString(value) + } + values = append(values, value) + // skip ' and , + i = j + 2 + } + + // need skip blank??? + } + + return values, nil +} + +// unescapeString un-escapes the string. +// mysqldump will escape the string when dumps, +// Refer http://dev.mysql.com/doc/refman/5.7/en/string-literals.html +func unescapeString(s string) string { + i := 0 + + value := make([]byte, 0, len(s)) + for i < len(s) { + if s[i] == '\\' { + j := i + 1 + if j == len(s) { + // The last char is \, remove + break + } + + value = append(value, unescapeChar(s[j])) + i += 2 + } else { + value = append(value, s[i]) + i++ + } + } + + return string(value) +} + +func unescapeChar(ch byte) byte { + // \" \' \\ \n \0 \b \Z \r \t ==> escape to one char + switch ch { + case 'n': + ch = '\n' + case '0': + ch = 0 + case 'b': + ch = 8 + case 'Z': + ch = 26 + case 'r': + ch = '\r' + case 't': + ch = '\t' + } + return ch +} diff --git a/vendor/github.com/siddontang/go-mysql/failover/const.go b/vendor/github.com/siddontang/go-mysql/failover/const.go new file mode 100644 index 0000000..15e27d2 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/const.go @@ -0,0 +1,11 @@ +package failover + +const ( + IOThreadType = "IO_THREAD" + SQLThreadType = "SQL_THREAD" +) + +const ( + GTIDModeOn = "ON" + GTIDModeOff = "OFF" +) diff --git a/vendor/github.com/siddontang/go-mysql/failover/doc.go b/vendor/github.com/siddontang/go-mysql/failover/doc.go new file mode 100644 index 0000000..feb037d --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/doc.go @@ -0,0 +1,8 @@ +// Failover supports to promote a new master and let other slaves +// replicate from it automatically. +// +// Failover does not support monitoring whether a master is alive or not, +// and will think the master is down. +// +// This package is still in development and could not be used in production environment. +package failover diff --git a/vendor/github.com/siddontang/go-mysql/failover/failover.go b/vendor/github.com/siddontang/go-mysql/failover/failover.go new file mode 100644 index 0000000..945c046 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/failover.go @@ -0,0 +1,68 @@ +package failover + +import ( + "github.com/juju/errors" + "github.com/siddontang/go-mysql/mysql" +) + +// Failover will do below things after the master down +// 1. Elect a slave which has the most up-to-date data with old master +// 2. Promote the slave to new master +// 3. Change other slaves to the new master +// +// Limitation: +// 1, All slaves must have the same master before, Failover will check using master server id or uuid +// 2, If the failover error, the whole topology may be wrong, we must handle this error manually +// 3, Slaves must have same replication mode, all use GTID or not +// +func Failover(flavor string, slaves []*Server) ([]*Server, error) { + var h Handler + var err error + + switch flavor { + case mysql.MySQLFlavor: + h = new(MysqlGTIDHandler) + case mysql.MariaDBFlavor: + return nil, errors.Errorf("MariaDB failover is not supported now") + default: + return nil, errors.Errorf("invalid flavor %s", flavor) + } + + // First check slaves use gtid or not + if err := h.CheckGTIDMode(slaves); err != nil { + return nil, errors.Trace(err) + } + + // Stop all slave IO_THREAD and wait the relay log done + for _, slave := range slaves { + if err = h.WaitRelayLogDone(slave); err != nil { + return nil, errors.Trace(err) + } + } + + var bestSlave *Server + // Find best slave which has the most up-to-data data + if bestSlaves, err := h.FindBestSlaves(slaves); err != nil { + return nil, errors.Trace(err) + } else { + bestSlave = bestSlaves[0] + } + + // Promote the best slave to master + if err = h.Promote(bestSlave); err != nil { + return nil, errors.Trace(err) + } + + // Change master + for i := 0; i < len(slaves); i++ { + if bestSlave == slaves[i] { + continue + } + + if err = h.ChangeMasterTo(slaves[i], bestSlave); err != nil { + return nil, errors.Trace(err) + } + } + + return slaves, nil +} diff --git a/vendor/github.com/siddontang/go-mysql/failover/failover_test.go b/vendor/github.com/siddontang/go-mysql/failover/failover_test.go new file mode 100644 index 0000000..f3e4f19 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/failover_test.go @@ -0,0 +1,176 @@ +package failover + +import ( + "flag" + "fmt" + "testing" + + . "github.com/pingcap/check" +) + +// We will use go-mysql docker to test +// go-mysql docker will build mysql 1-3 instances +var host = flag.String("host", "127.0.0.1", "go-mysql docker container address") +var enable_failover_test = flag.Bool("test-failover", false, "enable test failover") + +func Test(t *testing.T) { + TestingT(t) +} + +type failoverTestSuite struct { + s []*Server +} + +var _ = Suite(&failoverTestSuite{}) + +func (s *failoverTestSuite) SetUpSuite(c *C) { + if !*enable_failover_test { + c.Skip("skip test failover") + } + + ports := []int{3306, 3307, 3308, 3316, 3317, 3318} + + s.s = make([]*Server, len(ports)) + + for i := 0; i < len(ports); i++ { + s.s[i] = NewServer(fmt.Sprintf("%s:%d", *host, ports[i]), User{"root", ""}, User{"root", ""}) + } + + var err error + for i := 0; i < len(ports); i++ { + err = s.s[i].StopSlave() + c.Assert(err, IsNil) + + err = s.s[i].ResetSlaveALL() + c.Assert(err, IsNil) + + _, err = s.s[i].Execute(`SET GLOBAL BINLOG_FORMAT = "ROW"`) + c.Assert(err, IsNil) + + _, err = s.s[i].Execute("DROP TABLE IF EXISTS test.go_mysql_test") + c.Assert(err, IsNil) + + _, err = s.s[i].Execute("CREATE TABLE IF NOT EXISTS test.go_mysql_test (id INT AUTO_INCREMENT, name VARCHAR(256), PRIMARY KEY(id)) engine=innodb") + c.Assert(err, IsNil) + + err = s.s[i].ResetMaster() + c.Assert(err, IsNil) + } +} + +func (s *failoverTestSuite) TearDownSuite(c *C) { +} + +func (s *failoverTestSuite) TestMysqlFailover(c *C) { + h := new(MysqlGTIDHandler) + + m := s.s[0] + s1 := s.s[1] + s2 := s.s[2] + + s.testFailover(c, h, m, s1, s2) +} + +func (s *failoverTestSuite) TestMariadbFailover(c *C) { + h := new(MariadbGTIDHandler) + + for i := 3; i <= 5; i++ { + _, err := s.s[i].Execute("SET GLOBAL gtid_slave_pos = ''") + c.Assert(err, IsNil) + } + + m := s.s[3] + s1 := s.s[4] + s2 := s.s[5] + + s.testFailover(c, h, m, s1, s2) +} + +func (s *failoverTestSuite) testFailover(c *C, h Handler, m *Server, s1 *Server, s2 *Server) { + var err error + err = h.ChangeMasterTo(s1, m) + c.Assert(err, IsNil) + + err = h.ChangeMasterTo(s2, m) + c.Assert(err, IsNil) + + id := s.checkInsert(c, m, "a") + + err = h.WaitCatchMaster(s1, m) + c.Assert(err, IsNil) + + err = h.WaitCatchMaster(s2, m) + c.Assert(err, IsNil) + + s.checkSelect(c, s1, id, "a") + s.checkSelect(c, s2, id, "a") + + err = s2.StopSlaveIOThread() + c.Assert(err, IsNil) + + id = s.checkInsert(c, m, "b") + id = s.checkInsert(c, m, "c") + + err = h.WaitCatchMaster(s1, m) + c.Assert(err, IsNil) + + s.checkSelect(c, s1, id, "c") + + best, err := h.FindBestSlaves([]*Server{s1, s2}) + c.Assert(err, IsNil) + c.Assert(best, DeepEquals, []*Server{s1}) + + // promote s1 to master + err = h.Promote(s1) + c.Assert(err, IsNil) + + // change s2 to master s1 + err = h.ChangeMasterTo(s2, s1) + c.Assert(err, IsNil) + + err = h.WaitCatchMaster(s2, s1) + c.Assert(err, IsNil) + + s.checkSelect(c, s2, id, "c") + + // change m to master s1 + err = h.ChangeMasterTo(m, s1) + c.Assert(err, IsNil) + + m, s1 = s1, m + id = s.checkInsert(c, m, "d") + + err = h.WaitCatchMaster(s1, m) + c.Assert(err, IsNil) + + err = h.WaitCatchMaster(s2, m) + c.Assert(err, IsNil) + + best, err = h.FindBestSlaves([]*Server{s1, s2}) + c.Assert(err, IsNil) + c.Assert(best, DeepEquals, []*Server{s1, s2}) + + err = s2.StopSlaveIOThread() + c.Assert(err, IsNil) + + id = s.checkInsert(c, m, "e") + err = h.WaitCatchMaster(s1, m) + + best, err = h.FindBestSlaves([]*Server{s1, s2}) + c.Assert(err, IsNil) + c.Assert(best, DeepEquals, []*Server{s1}) +} + +func (s *failoverTestSuite) checkSelect(c *C, m *Server, id uint64, name string) { + rr, err := m.Execute("SELECT name FROM test.go_mysql_test WHERE id = ?", id) + c.Assert(err, IsNil) + str, _ := rr.GetString(0, 0) + c.Assert(str, Equals, name) +} + +func (s *failoverTestSuite) checkInsert(c *C, m *Server, name string) uint64 { + r, err := m.Execute("INSERT INTO test.go_mysql_test (name) VALUES (?)", name) + c.Assert(err, IsNil) + + return r.InsertId +} diff --git a/vendor/github.com/siddontang/go-mysql/failover/handler.go b/vendor/github.com/siddontang/go-mysql/failover/handler.go new file mode 100644 index 0000000..f750767 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/handler.go @@ -0,0 +1,22 @@ +package failover + +type Handler interface { + // Promote slave s to master + Promote(s *Server) error + + // Change slave s to master m and replicate from it + ChangeMasterTo(s *Server, m *Server) error + + // Ensure all relay log done, it will stop slave IO_THREAD + // You must start slave again if you want to do replication continuatively + WaitRelayLogDone(s *Server) error + + // Wait until slave s catch all data from master m at current time + WaitCatchMaster(s *Server, m *Server) error + + // Find best slave which has the most up-to-date data from master + FindBestSlaves(slaves []*Server) ([]*Server, error) + + // Check all slaves have gtid enabled + CheckGTIDMode(slaves []*Server) error +} diff --git a/vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go b/vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go new file mode 100644 index 0000000..5c8bd64 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/mariadb_gtid_handler.go @@ -0,0 +1,142 @@ +package failover + +import ( + "fmt" + "net" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" +) + +// Limiatation +// + Multi source replication is not supported +// + Slave can not handle write transactions, so maybe readonly or strict_mode = 1 is better +type MariadbGTIDHandler struct { + Handler +} + +func (h *MariadbGTIDHandler) Promote(s *Server) error { + if err := h.WaitRelayLogDone(s); err != nil { + return errors.Trace(err) + } + + if err := s.StopSlave(); err != nil { + return errors.Trace(err) + } + + return nil +} + +func (h *MariadbGTIDHandler) FindBestSlaves(slaves []*Server) ([]*Server, error) { + bestSlaves := []*Server{} + + ps := make([]uint64, len(slaves)) + + lastIndex := -1 + var seq uint64 + + for i, slave := range slaves { + rr, err := slave.Execute("SELECT @@gtid_current_pos") + + if err != nil { + return nil, errors.Trace(err) + } + + str, _ := rr.GetString(0, 0) + if len(str) == 0 { + seq = 0 + } else { + g, err := ParseMariadbGTIDSet(str) + if err != nil { + return nil, errors.Trace(err) + } + + seq = g.(MariadbGTID).SequenceNumber + } + + ps[i] = seq + + if lastIndex == -1 { + lastIndex = i + bestSlaves = []*Server{slave} + } else { + if ps[lastIndex] < seq { + lastIndex = i + bestSlaves = []*Server{slave} + } else if ps[lastIndex] == seq { + // these two slaves have same data, + bestSlaves = append(bestSlaves, slave) + } + } + } + + return bestSlaves, nil +} + +const changeMasterToWithCurrentPos = `CHANGE MASTER TO + MASTER_HOST = "%s", MASTER_PORT = %s, + MASTER_USER = "%s", MASTER_PASSWORD = "%s", + MASTER_USE_GTID = current_pos` + +func (h *MariadbGTIDHandler) ChangeMasterTo(s *Server, m *Server) error { + if err := h.WaitRelayLogDone(s); err != nil { + return errors.Trace(err) + } + + if err := s.StopSlave(); err != nil { + return errors.Trace(err) + } + + if err := s.ResetSlave(); err != nil { + return errors.Trace(err) + } + + host, port, _ := net.SplitHostPort(m.Addr) + + if _, err := s.Execute(fmt.Sprintf(changeMasterToWithCurrentPos, + host, port, m.ReplUser.Name, m.ReplUser.Password)); err != nil { + return errors.Trace(err) + } + + if err := s.StartSlave(); err != nil { + return errors.Trace(err) + } + + return nil +} + +func (h *MariadbGTIDHandler) WaitRelayLogDone(s *Server) error { + if err := s.StopSlaveIOThread(); err != nil { + return errors.Trace(err) + } + + r, err := s.SlaveStatus() + if err != nil { + return errors.Trace(err) + } + + fname, _ := r.GetStringByName(0, "Master_Log_File") + pos, _ := r.GetIntByName(0, "Read_Master_Log_Pos") + + return s.MasterPosWait(Position{fname, uint32(pos)}, 0) +} + +func (h *MariadbGTIDHandler) WaitCatchMaster(s *Server, m *Server) error { + r, err := m.Execute("SELECT @@gtid_binlog_pos") + if err != nil { + return errors.Trace(err) + } + + pos, _ := r.GetString(0, 0) + + return h.waitUntilAfterGTID(s, pos) +} + +func (h *MariadbGTIDHandler) CheckGTIDMode(slaves []*Server) error { + return nil +} + +func (h *MariadbGTIDHandler) waitUntilAfterGTID(s *Server, pos string) error { + _, err := s.Execute(fmt.Sprintf("SELECT MASTER_GTID_WAIT('%s')", pos)) + return err +} diff --git a/vendor/github.com/siddontang/go-mysql/failover/mysql_gtid_handler.go b/vendor/github.com/siddontang/go-mysql/failover/mysql_gtid_handler.go new file mode 100644 index 0000000..322913c --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/mysql_gtid_handler.go @@ -0,0 +1,141 @@ +package failover + +import ( + "fmt" + "net" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" +) + +type MysqlGTIDHandler struct { + Handler +} + +func (h *MysqlGTIDHandler) Promote(s *Server) error { + if err := h.WaitRelayLogDone(s); err != nil { + return errors.Trace(err) + } + + if err := s.StopSlave(); err != nil { + return errors.Trace(err) + } + + return nil +} + +func (h *MysqlGTIDHandler) FindBestSlaves(slaves []*Server) ([]*Server, error) { + // MHA use Relay_Master_Log_File and Exec_Master_Log_Pos to determind which is the best slave + + bestSlaves := []*Server{} + + ps := make([]Position, len(slaves)) + + lastIndex := -1 + + for i, slave := range slaves { + pos, err := slave.FetchSlaveExecutePos() + + if err != nil { + return nil, errors.Trace(err) + } + + ps[i] = pos + + if lastIndex == -1 { + lastIndex = i + bestSlaves = []*Server{slave} + } else { + switch ps[lastIndex].Compare(pos) { + case 1: + //do nothing + case -1: + lastIndex = i + bestSlaves = []*Server{slave} + case 0: + // these two slaves have same data, + bestSlaves = append(bestSlaves, slave) + } + } + } + + return bestSlaves, nil +} + +const changeMasterToWithAuto = `CHANGE MASTER TO + MASTER_HOST = "%s", MASTER_PORT = %s, + MASTER_USER = "%s", MASTER_PASSWORD = "%s", + MASTER_AUTO_POSITION = 1` + +func (h *MysqlGTIDHandler) ChangeMasterTo(s *Server, m *Server) error { + if err := h.WaitRelayLogDone(s); err != nil { + return errors.Trace(err) + } + + if err := s.StopSlave(); err != nil { + return errors.Trace(err) + } + + if err := s.ResetSlave(); err != nil { + return errors.Trace(err) + } + + host, port, _ := net.SplitHostPort(m.Addr) + + if _, err := s.Execute(fmt.Sprintf(changeMasterToWithAuto, + host, port, m.ReplUser.Name, m.ReplUser.Password)); err != nil { + return errors.Trace(err) + } + + if err := s.StartSlave(); err != nil { + return errors.Trace(err) + } + + return nil +} + +func (h *MysqlGTIDHandler) WaitRelayLogDone(s *Server) error { + if err := s.StopSlaveIOThread(); err != nil { + return errors.Trace(err) + } + + r, err := s.SlaveStatus() + if err != nil { + return errors.Trace(err) + } + + retrieved, _ := r.GetStringByName(0, "Retrieved_Gtid_Set") + + // may only support MySQL version >= 5.6.9 + // see http://dev.mysql.com/doc/refman/5.6/en/gtid-functions.html + return h.waitUntilAfterGTIDs(s, retrieved) +} + +func (h *MysqlGTIDHandler) WaitCatchMaster(s *Server, m *Server) error { + r, err := m.MasterStatus() + if err != nil { + return errors.Trace(err) + } + + masterGTIDSet, _ := r.GetStringByName(0, "Executed_Gtid_Set") + + return h.waitUntilAfterGTIDs(s, masterGTIDSet) +} + +func (h *MysqlGTIDHandler) CheckGTIDMode(slaves []*Server) error { + for i := 0; i < len(slaves); i++ { + mode, err := slaves[i].MysqlGTIDMode() + if err != nil { + return errors.Trace(err) + } else if mode != GTIDModeOn { + return errors.Errorf("%s use not GTID mode", slaves[i].Addr) + } + } + + return nil +} + +func (h *MysqlGTIDHandler) waitUntilAfterGTIDs(s *Server, gtids string) error { + _, err := s.Execute(fmt.Sprintf("SELECT WAIT_UNTIL_SQL_THREAD_AFTER_GTIDS('%s')", gtids)) + return err +} diff --git a/vendor/github.com/siddontang/go-mysql/failover/server.go b/vendor/github.com/siddontang/go-mysql/failover/server.go new file mode 100644 index 0000000..ff87464 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/failover/server.go @@ -0,0 +1,174 @@ +package failover + +import ( + "fmt" + + "github.com/siddontang/go-mysql/client" + . "github.com/siddontang/go-mysql/mysql" +) + +type User struct { + Name string + Password string +} + +type Server struct { + Addr string + + User User + ReplUser User + + conn *client.Conn +} + +func NewServer(addr string, user User, replUser User) *Server { + s := new(Server) + + s.Addr = addr + + s.User = user + s.ReplUser = replUser + + return s +} + +func (s *Server) Close() { + if s.conn != nil { + s.conn.Close() + } +} + +func (s *Server) Execute(cmd string, args ...interface{}) (r *Result, err error) { + retryNum := 3 + for i := 0; i < retryNum; i++ { + if s.conn == nil { + s.conn, err = client.Connect(s.Addr, s.User.Name, s.User.Password, "") + if err != nil { + return nil, err + } + } + + r, err = s.conn.Execute(cmd, args...) + if err != nil && ErrorEqual(err, ErrBadConn) { + return + } else if ErrorEqual(err, ErrBadConn) { + s.conn = nil + continue + } else { + return + } + } + return +} + +func (s *Server) StartSlave() error { + _, err := s.Execute("START SLAVE") + return err +} + +func (s *Server) StopSlave() error { + _, err := s.Execute("STOP SLAVE") + return err +} + +func (s *Server) StopSlaveIOThread() error { + _, err := s.Execute("STOP SLAVE IO_THREAD") + return err +} + +func (s *Server) SlaveStatus() (*Resultset, error) { + r, err := s.Execute("SHOW SLAVE STATUS") + if err != nil { + return nil, err + } else { + return r.Resultset, nil + } +} + +func (s *Server) MasterStatus() (*Resultset, error) { + r, err := s.Execute("SHOW MASTER STATUS") + if err != nil { + return nil, err + } else { + return r.Resultset, nil + } +} + +func (s *Server) ResetSlave() error { + _, err := s.Execute("RESET SLAVE") + return err +} + +func (s *Server) ResetSlaveALL() error { + _, err := s.Execute("RESET SLAVE ALL") + return err +} + +func (s *Server) ResetMaster() error { + _, err := s.Execute("RESET MASTER") + return err +} + +func (s *Server) MysqlGTIDMode() (string, error) { + r, err := s.Execute("SELECT @@gtid_mode") + if err != nil { + return GTIDModeOff, err + } + on, _ := r.GetString(0, 0) + if on != GTIDModeOn { + return GTIDModeOff, nil + } else { + return GTIDModeOn, nil + } +} + +func (s *Server) SetReadonly(b bool) error { + var err error + if b { + _, err = s.Execute("SET GLOBAL read_only = ON") + } else { + _, err = s.Execute("SET GLOBAL read_only = OFF") + } + return err +} + +func (s *Server) LockTables() error { + _, err := s.Execute("FLUSH TABLES WITH READ LOCK") + return err +} + +func (s *Server) UnlockTables() error { + _, err := s.Execute("UNLOCK TABLES") + return err +} + +// Get current binlog filename and position read from master +func (s *Server) FetchSlaveReadPos() (Position, error) { + r, err := s.SlaveStatus() + if err != nil { + return Position{}, err + } + + fname, _ := r.GetStringByName(0, "Master_Log_File") + pos, _ := r.GetIntByName(0, "Read_Master_Log_Pos") + + return Position{fname, uint32(pos)}, nil +} + +// Get current executed binlog filename and position from master +func (s *Server) FetchSlaveExecutePos() (Position, error) { + r, err := s.SlaveStatus() + if err != nil { + return Position{}, err + } + + fname, _ := r.GetStringByName(0, "Relay_Master_Log_File") + pos, _ := r.GetIntByName(0, "Exec_Master_Log_Pos") + + return Position{fname, uint32(pos)}, nil +} + +func (s *Server) MasterPosWait(pos Position, timeout int) error { + _, err := s.Execute(fmt.Sprintf("SELECT MASTER_POS_WAIT('%s', %d, %d)", pos.Name, pos.Pos, timeout)) + return err +} diff --git a/vendor/github.com/siddontang/go-mysql/glide.lock b/vendor/github.com/siddontang/go-mysql/glide.lock new file mode 100644 index 0000000..7d71953 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/glide.lock @@ -0,0 +1,30 @@ +hash: 1a3d05afef96cd7601a004e573b128db9051eecd1b5d0a3d69d3fa1ee1a3e3b8 +updated: 2016-09-03T12:30:00.028685232+08:00 +imports: +- name: github.com/BurntSushi/toml + version: 056c9bc7be7190eaa7715723883caffa5f8fa3e4 +- name: github.com/go-sql-driver/mysql + version: 3654d25ec346ee8ce71a68431025458d52a38ac0 +- name: github.com/jmoiron/sqlx + version: 54aec3fd91a2b2129ffaca0d652b8a9223ee2d9e + subpackages: + - reflectx +- name: github.com/juju/errors + version: 6f54ff6318409d31ff16261533ce2c8381a4fd5d +- name: github.com/ngaut/log + version: cec23d3e10b016363780d894a0eb732a12c06e02 +- name: github.com/pingcap/check + version: ce8a2f822ab1e245a4eefcef2996531c79c943f1 +- name: github.com/satori/go.uuid + version: 879c5887cd475cd7864858769793b2ceb0d44feb +- name: github.com/siddontang/go + version: 354e14e6c093c661abb29fd28403b3c19cff5514 + subpackages: + - hack + - ioutil2 + - sync2 +- name: golang.org/x/net + version: 6acef71eb69611914f7a30939ea9f6e194c78172 + subpackages: + - context +testImports: [] diff --git a/vendor/github.com/siddontang/go-mysql/glide.yaml b/vendor/github.com/siddontang/go-mysql/glide.yaml new file mode 100644 index 0000000..3561648 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/glide.yaml @@ -0,0 +1,26 @@ +package: github.com/siddontang/go-mysql +import: +- package: github.com/BurntSushi/toml + version: 056c9bc7be7190eaa7715723883caffa5f8fa3e4 +- package: github.com/go-sql-driver/mysql + version: 3654d25ec346ee8ce71a68431025458d52a38ac0 +- package: github.com/jmoiron/sqlx + version: 54aec3fd91a2b2129ffaca0d652b8a9223ee2d9e + subpackages: + - reflectx +- package: github.com/juju/errors + version: 6f54ff6318409d31ff16261533ce2c8381a4fd5d +- package: github.com/ngaut/log + version: cec23d3e10b016363780d894a0eb732a12c06e02 +- package: github.com/pingcap/check + version: ce8a2f822ab1e245a4eefcef2996531c79c943f1 +- package: github.com/satori/go.uuid + version: ^1.1.0 +- package: github.com/siddontang/go + version: 354e14e6c093c661abb29fd28403b3c19cff5514 + subpackages: + - hack + - ioutil2 + - sync2 +- package: golang.org/x/net + version: 6acef71eb69611914f7a30939ea9f6e194c78172 diff --git a/vendor/github.com/siddontang/go-mysql/mysql/const.go b/vendor/github.com/siddontang/go-mysql/mysql/const.go index cd15642..2f4ab63 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/const.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/const.go @@ -119,7 +119,8 @@ const ( ) const ( - MYSQL_TYPE_NEWDECIMAL byte = iota + 0xf6 + MYSQL_TYPE_JSON byte = iota + 0xf5 + MYSQL_TYPE_NEWDECIMAL MYSQL_TYPE_ENUM MYSQL_TYPE_SET MYSQL_TYPE_TINY_BLOB diff --git a/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go b/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go index ce251ee..8fbaaab 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/mysql_test.go @@ -3,7 +3,7 @@ package mysql import ( "testing" - "gopkg.in/check.v1" + "github.com/pingcap/check" ) func Test(t *testing.T) { @@ -114,40 +114,40 @@ func (t *mysqlTestSuite) TestMysqlParseBinaryUint8(c *check.C) { func (t *mysqlTestSuite) TestMysqlParseBinaryInt16(c *check.C) { i16 := ParseBinaryInt16([]byte{1, 128}) - c.Assert(i16, check.Equals, int16(-128*256 + 1)) + c.Assert(i16, check.Equals, int16(-128*256+1)) } func (t *mysqlTestSuite) TestMysqlParseBinaryUint16(c *check.C) { u16 := ParseBinaryUint16([]byte{1, 128}) - c.Assert(u16, check.Equals, uint16(128*256 + 1)) + c.Assert(u16, check.Equals, uint16(128*256+1)) } func (t *mysqlTestSuite) TestMysqlParseBinaryInt24(c *check.C) { i32 := ParseBinaryInt24([]byte{1, 2, 128}) - c.Assert(i32, check.Equals, int32(-128*65536 + 2*256 + 1)) + c.Assert(i32, check.Equals, int32(-128*65536+2*256+1)) } func (t *mysqlTestSuite) TestMysqlParseBinaryUint24(c *check.C) { u32 := ParseBinaryUint24([]byte{1, 2, 128}) - c.Assert(u32, check.Equals, uint32(128*65536 + 2*256 + 1)) + c.Assert(u32, check.Equals, uint32(128*65536+2*256+1)) } func (t *mysqlTestSuite) TestMysqlParseBinaryInt32(c *check.C) { i32 := ParseBinaryInt32([]byte{1, 2, 3, 128}) - c.Assert(i32, check.Equals, int32(-128*16777216 + 3*65536 + 2*256 + 1)) + c.Assert(i32, check.Equals, int32(-128*16777216+3*65536+2*256+1)) } func (t *mysqlTestSuite) TestMysqlParseBinaryUint32(c *check.C) { u32 := ParseBinaryUint32([]byte{1, 2, 3, 128}) - c.Assert(u32, check.Equals, uint32(128*16777216 + 3*65536 + 2*256 + 1)) + c.Assert(u32, check.Equals, uint32(128*16777216+3*65536+2*256+1)) } func (t *mysqlTestSuite) TestMysqlParseBinaryInt64(c *check.C) { i64 := ParseBinaryInt64([]byte{1, 2, 3, 4, 5, 6, 7, 128}) - c.Assert(i64, check.Equals, -128*int64(72057594037927936) + 7*int64(281474976710656) + 6*int64(1099511627776) + 5*int64(4294967296) + 4*16777216 + 3*65536 + 2*256 + 1) + c.Assert(i64, check.Equals, -128*int64(72057594037927936)+7*int64(281474976710656)+6*int64(1099511627776)+5*int64(4294967296)+4*16777216+3*65536+2*256+1) } func (t *mysqlTestSuite) TestMysqlParseBinaryUint64(c *check.C) { u64 := ParseBinaryUint64([]byte{1, 2, 3, 4, 5, 6, 7, 128}) - c.Assert(u64, check.Equals, 128*uint64(72057594037927936) + 7*uint64(281474976710656) + 6*uint64(1099511627776) + 5*uint64(4294967296) + 4*16777216 + 3*65536 + 2*256 + 1) + c.Assert(u64, check.Equals, 128*uint64(72057594037927936)+7*uint64(281474976710656)+6*uint64(1099511627776)+5*uint64(4294967296)+4*16777216+3*65536+2*256+1) } diff --git a/vendor/github.com/siddontang/go-mysql/mysql/resultset.go b/vendor/github.com/siddontang/go-mysql/mysql/resultset.go index 1e2ade5..a50d9f8 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/resultset.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/resultset.go @@ -1,6 +1,7 @@ package mysql import ( + "fmt" "strconv" "github.com/juju/errors" @@ -141,7 +142,7 @@ func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) { continue case MYSQL_TYPE_DOUBLE: - data[i] = ParseBinaryFloat64(p[pos : pos+4]) + data[i] = ParseBinaryFloat64(p[pos : pos+8]) pos += 8 continue @@ -292,10 +293,28 @@ func (r *Resultset) GetUint(row, column int) (uint64, error) { } switch v := d.(type) { - case uint64: - return v, nil + case int: + return uint64(v), nil + case int8: + return uint64(v), nil + case int16: + return uint64(v), nil + case int32: + return uint64(v), nil case int64: return uint64(v), nil + case uint: + return uint64(v), nil + case uint8: + return uint64(v), nil + case uint16: + return uint64(v), nil + case uint32: + return uint64(v), nil + case uint64: + return uint64(v), nil + case float32: + return uint64(v), nil case float64: return uint64(v), nil case string: @@ -342,12 +361,30 @@ func (r *Resultset) GetFloat(row, column int) (float64, error) { } switch v := d.(type) { - case float64: - return v, nil - case uint64: + case int: + return float64(v), nil + case int8: + return float64(v), nil + case int16: + return float64(v), nil + case int32: return float64(v), nil case int64: return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + case float32: + return float64(v), nil + case float64: + return v, nil case string: return strconv.ParseFloat(v, 64) case []byte: @@ -378,10 +415,11 @@ func (r *Resultset) GetString(row, column int) (string, error) { return v, nil case []byte: return hack.String(v), nil - case int64: - return strconv.FormatInt(v, 10), nil - case uint64: - return strconv.FormatUint(v, 10), nil + case int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%d", v), nil + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 64), nil case float64: return strconv.FormatFloat(v, 'f', -1, 64), nil case nil: diff --git a/vendor/github.com/siddontang/go-mysql/mysql/util.go b/vendor/github.com/siddontang/go-mysql/mysql/util.go index fa874de..7fe41fa 100644 --- a/vendor/github.com/siddontang/go-mysql/mysql/util.go +++ b/vendor/github.com/siddontang/go-mysql/mysql/util.go @@ -314,6 +314,22 @@ func GetNetProto(addr string) string { } } +// ErrorEqual returns a boolean indicating whether err1 is equal to err2. +func ErrorEqual(err1, err2 error) bool { + e1 := errors.Cause(err1) + e2 := errors.Cause(err2) + + if e1 == e2 { + return true + } + + if e1 == nil || e2 == nil { + return e1 == e2 + } + + return e1.Error() == e2.Error() +} + var encodeRef = map[byte]byte{ '\x00': '0', '\'': '\'', diff --git a/vendor/github.com/siddontang/go-mysql/replication/backup.go b/vendor/github.com/siddontang/go-mysql/replication/backup.go index f1fda21..744c38c 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/backup.go +++ b/vendor/github.com/siddontang/go-mysql/replication/backup.go @@ -6,6 +6,8 @@ import ( "path" "time" + "golang.org/x/net/context" + "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" ) @@ -18,7 +20,8 @@ func (b *BinlogSyncer) StartBackup(backupDir string, p Position, timeout time.Du timeout = 30 * 3600 * 24 * time.Second } - b.SetRawMode(true) + // Force use raw mode + b.parser.SetRawMode(true) os.MkdirAll(backupDir, 0755) @@ -38,11 +41,15 @@ func (b *BinlogSyncer) StartBackup(backupDir string, p Position, timeout time.Du }() for { - e, err := s.GetEventTimeout(timeout) - if errors.Cause(err) == ErrGetEventTimeout { - // for backward compatibility - return ErrGetEventTimeout - } else if err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + e, err := s.GetEvent(ctx) + cancel() + + if err == context.DeadlineExceeded { + return nil + } + + if err != nil { return errors.Trace(err) } diff --git a/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go b/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go index 706fe08..e5b165c 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go +++ b/vendor/github.com/siddontang/go-mysql/replication/binlogstreamer.go @@ -1,24 +1,27 @@ package replication import ( - "time" + "golang.org/x/net/context" "github.com/juju/errors" + "github.com/ngaut/log" ) var ( - ErrGetEventTimeout = errors.New("Get event timeout, try get later") - ErrNeedSyncAgain = errors.New("Last sync error or closed, try sync and get event again") - ErrSyncClosed = errors.New("Sync was closed") + ErrNeedSyncAgain = errors.New("Last sync error or closed, try sync and get event again") + ErrSyncClosed = errors.New("Sync was closed") ) +// BinlogStreamer gets the streaming event. type BinlogStreamer struct { ch chan *BinlogEvent ech chan error err error } -func (s *BinlogStreamer) GetEvent() (*BinlogEvent, error) { +// GetEvent gets the binlog event one by one, it will block until Syncer receives any events from MySQL +// or meets a sync error. You can pass a context (like Cancel or Timeout) to break the block. +func (s *BinlogStreamer) GetEvent(ctx context.Context) (*BinlogEvent, error) { if s.err != nil { return nil, ErrNeedSyncAgain } @@ -28,23 +31,8 @@ func (s *BinlogStreamer) GetEvent() (*BinlogEvent, error) { return c, nil case s.err = <-s.ech: return nil, s.err - } -} - -// if timeout, ErrGetEventTimeout will returns -// timeout value won't be set too large, otherwise it may waste lots of memory -func (s *BinlogStreamer) GetEventTimeout(d time.Duration) (*BinlogEvent, error) { - if s.err != nil { - return nil, ErrNeedSyncAgain - } - - select { - case c := <-s.ch: - return c, nil - case s.err = <-s.ech: - return nil, s.err - case <-time.After(d): - return nil, ErrGetEventTimeout + case <-ctx.Done(): + return nil, ctx.Err() } } @@ -56,6 +44,7 @@ func (s *BinlogStreamer) closeWithError(err error) { if err == nil { err = ErrSyncClosed } + log.Errorf("close sync with err: %v", err) select { case s.ech <- err: default: @@ -65,7 +54,7 @@ func (s *BinlogStreamer) closeWithError(err error) { func newBinlogStreamer() *BinlogStreamer { s := new(BinlogStreamer) - s.ch = make(chan *BinlogEvent, 1024) + s.ch = make(chan *BinlogEvent, 10240) s.ech = make(chan error, 4) return s diff --git a/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go b/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go index bcea08a..90bf691 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go +++ b/vendor/github.com/siddontang/go-mysql/replication/binlogsyncer.go @@ -1,38 +1,62 @@ package replication import ( + "crypto/tls" "encoding/binary" "fmt" "os" "sync" "time" + "golang.org/x/net/context" + "github.com/juju/errors" - "github.com/satori/go.uuid" + "github.com/ngaut/log" "github.com/siddontang/go-mysql/client" . "github.com/siddontang/go-mysql/mysql" ) var ( - errSyncRunning = errors.New("Sync is running, must Close first") - errNotRegistered = errors.New("Syncer is not registered as a slave") + errSyncRunning = errors.New("Sync is running, must Close first") ) +// BinlogSyncerConfig is the configuration for BinlogSyncer. +type BinlogSyncerConfig struct { + // ServerID is the unique ID in cluster. + ServerID uint32 + // Flavor is "mysql" or "mariadb", if not set, use "mysql" default. + Flavor string + + // Host is for MySQL server host. + Host string + // Port is for MySQL server port. + Port uint16 + // User is for MySQL user. + User string + // Password is for MySQL password. + Password string + + // Localhost is local hostname if register salve. + // If not set, use os.Hostname() instead. + Localhost string + + // SemiSyncEnabled enables semi-sync or not. + SemiSyncEnabled bool + + // RawModeEanbled is for not parsing binlog event. + RawModeEanbled bool + + // If not nil, use the provided tls.Config to connect to the database using TLS/SSL. + TLSConfig *tls.Config +} + +// BinlogSyncer syncs binlog event from server. type BinlogSyncer struct { - m sync.Mutex + m sync.RWMutex - flavor string + cfg *BinlogSyncerConfig - c *client.Conn - serverID uint32 - - localhost string - host string - port uint16 - user string - password string - - masterID uint32 + c *client.Conn wg sync.WaitGroup @@ -40,45 +64,29 @@ type BinlogSyncer struct { nextPos Position - running bool - semiSyncEnabled bool + running bool - stopCh chan struct{} + ctx context.Context + cancel context.CancelFunc } -func NewBinlogSyncer(serverID uint32, flavor string) *BinlogSyncer { +// NewBinlogSyncer creates the BinlogSyncer with cfg. +func NewBinlogSyncer(cfg *BinlogSyncerConfig) *BinlogSyncer { + log.Infof("create BinlogSyncer with config %v", cfg) + b := new(BinlogSyncer) - b.flavor = flavor - - b.serverID = serverID - - b.masterID = 0 + b.cfg = cfg b.parser = NewBinlogParser() + b.parser.SetRawMode(b.cfg.RawModeEanbled) b.running = false - b.semiSyncEnabled = false - - b.stopCh = make(chan struct{}, 1) + b.ctx, b.cancel = context.WithCancel(context.Background()) return b } -// LocalHostname returns the hostname that register slave would register as. -func (b *BinlogSyncer) LocalHostname() string { - - if b.localhost == "" { - h, _ := os.Hostname() - return h - } - return b.localhost -} - -// SetLocalHostname set's the hostname that register salve would register as. -func (b *BinlogSyncer) SetLocalHostname(name string) { - b.localhost = name -} - +// Close closes the BinlogSyncer. func (b *BinlogSyncer) Close() { b.m.Lock() defer b.m.Unlock() @@ -87,13 +95,17 @@ func (b *BinlogSyncer) Close() { } func (b *BinlogSyncer) close() { - if b.c != nil { - b.c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if b.isClosed() { + return } - select { - case b.stopCh <- struct{}{}: - default: + log.Info("syncer is closing...") + + b.running = false + b.cancel() + + if b.c != nil { + b.c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) } b.wg.Wait() @@ -102,84 +114,28 @@ func (b *BinlogSyncer) close() { b.c.Close() } - b.running = false - b.c = nil + log.Info("syncer is closed") } -func (b *BinlogSyncer) checkExec() error { - if b.running { - return errSyncRunning - } else if b.c == nil { - return errNotRegistered +func (b *BinlogSyncer) isClosed() bool { + select { + case <-b.ctx.Done(): + return true + default: + return false } - - return nil -} - -func (b *BinlogSyncer) GetMasterUUID() (uuid.UUID, error) { - b.m.Lock() - defer b.m.Unlock() - - if err := b.checkExec(); err != nil { - return uuid.UUID{}, err - } - - if r, err := b.c.Execute("SHOW GLOBAL VARIABLES LIKE 'SERVER_UUID'"); err != nil { - return uuid.UUID{}, err - } else { - s, _ := r.GetString(0, 1) - if s == "" || s == "NONE" { - return uuid.UUID{}, nil - } else { - return uuid.FromString(s) - } - } -} - -// You must register slave at first before you do other operations -// This function will close old replication sync if exists -func (b *BinlogSyncer) RegisterSlave(host string, port uint16, user string, password string) error { - b.m.Lock() - defer b.m.Unlock() - - // first, close old replication sync - b.close() - - b.host = host - b.port = port - b.user = user - b.password = password - - err := b.registerSlave() - if err != nil { - b.close() - } - return errors.Trace(err) -} - -// If you close sync before and want to restart again, you can call this before other operations -// This function will close old replication sync if exists -func (b *BinlogSyncer) ReRegisterSlave() error { - b.m.Lock() - defer b.m.Unlock() - - if len(b.host) == 0 || len(b.user) == 0 { - return errors.Errorf("empty host and user, you must register slave before") - } - - b.close() - - err := b.registerSlave() - if err != nil { - b.close() - } - - return errors.Trace(err) } func (b *BinlogSyncer) registerSlave() error { + if b.c != nil { + b.c.Close() + } + + log.Infof("register slave for master server %s:%d", b.cfg.Host, b.cfg.Port) var err error - b.c, err = client.Connect(fmt.Sprintf("%s:%d", b.host, b.port), b.user, b.password, "") + b.c, err = client.Connect(fmt.Sprintf("%s:%d", b.cfg.Host, b.cfg.Port), b.cfg.User, b.cfg.Password, "", func(c *client.Conn) { + c.TLSConfig = b.cfg.TLSConfig + }) if err != nil { return errors.Trace(err) } @@ -210,6 +166,15 @@ func (b *BinlogSyncer) registerSlave() error { } } + if b.cfg.Flavor == MariaDBFlavor { + // Refer https://github.com/alibaba/canal/wiki/BinlogChange(MariaDB5&10) + // Tell the server that we understand GTIDs by setting our slave capability + // to MARIA_SLAVE_CAPABILITY_GTID = 4 (MariaDB >= 10.0.1). + if _, err := b.c.Execute("SET @mariadb_slave_capability=4"); err != nil { + return errors.Errorf("failed to set @mariadb_slave_capability=4: %v", err) + } + } + if err = b.writeRegisterSlaveCommand(); err != nil { return errors.Trace(err) } @@ -221,12 +186,9 @@ func (b *BinlogSyncer) registerSlave() error { return nil } -func (b *BinlogSyncer) EnableSemiSync() error { - b.m.Lock() - defer b.m.Unlock() - - if err := b.checkExec(); err != nil { - return errors.Trace(err) +func (b *BinlogSyncer) enalbeSemiSync() error { + if !b.cfg.SemiSyncEnabled { + return nil } if r, err := b.c.Execute("SHOW VARIABLES LIKE 'rpl_semi_sync_master_enabled';"); err != nil { @@ -234,20 +196,38 @@ func (b *BinlogSyncer) EnableSemiSync() error { } else { s, _ := r.GetString(0, 1) if s != "ON" { - return errors.Errorf("master does not support semi synchronous replication") + log.Errorf("master does not support semi synchronous replication, use no semi-sync") + b.cfg.SemiSyncEnabled = false + return nil } } _, err := b.c.Execute(`SET @rpl_semi_sync_slave = 1;`) if err != nil { - b.semiSyncEnabled = true + return errors.Trace(err) } - return errors.Trace(err) + + return nil +} + +func (b *BinlogSyncer) prepare() error { + if b.isClosed() { + return errors.Trace(ErrSyncClosed) + } + + if err := b.registerSlave(); err != nil { + return errors.Trace(err) + } + + if err := b.enalbeSemiSync(); err != nil { + return errors.Trace(err) + } + + return nil } func (b *BinlogSyncer) startDumpStream() *BinlogStreamer { b.running = true - b.stopCh = make(chan struct{}, 1) s := newBinlogStreamer() @@ -256,66 +236,45 @@ func (b *BinlogSyncer) startDumpStream() *BinlogStreamer { return s } +// StartSync starts syncing from the `pos` position. func (b *BinlogSyncer) StartSync(pos Position) (*BinlogStreamer, error) { + log.Infof("begin to sync binlog from position %s", pos) + b.m.Lock() defer b.m.Unlock() - if err := b.checkExec(); err != nil { - return nil, err + if b.running { + return nil, errors.Trace(errSyncRunning) } - //always start from position 4 - if pos.Pos < 4 { - pos.Pos = 4 - } - - err := b.writeBinglogDumpCommand(pos) - if err != nil { - return nil, err + if err := b.prepareSyncPos(pos); err != nil { + return nil, errors.Trace(err) } return b.startDumpStream(), nil } -func (b *BinlogSyncer) SetRawMode(mode bool) error { - b.m.Lock() - defer b.m.Unlock() - - if err := b.checkExec(); err != nil { - return errors.Trace(err) - } - - b.parser.SetRawMode(mode) - return nil -} - -func (b *BinlogSyncer) ExecuteSql(query string, args ...interface{}) (*Result, error) { - b.m.Lock() - defer b.m.Unlock() - - if err := b.checkExec(); err != nil { - return nil, err - } - - return b.c.Execute(query, args...) -} - +// StartSyncGTID starts syncing from the `gset` GTIDSet. func (b *BinlogSyncer) StartSyncGTID(gset GTIDSet) (*BinlogStreamer, error) { + log.Infof("begin to sync binlog from GTID %s", gset) + b.m.Lock() defer b.m.Unlock() - if err := b.checkExec(); err != nil { - return nil, err + if b.running { + return nil, errors.Trace(errSyncRunning) + } + + if err := b.prepare(); err != nil { + return nil, errors.Trace(err) } var err error - switch b.flavor { - case MySQLFlavor: + if b.cfg.Flavor != MariaDBFlavor { + // default use MySQL err = b.writeBinlogDumpMysqlGTIDCommand(gset) - case MariaDBFlavor: + } else { err = b.writeBinlogDumpMariadbGTIDCommand(gset) - default: - err = fmt.Errorf("invalid flavor %s", b.flavor) } if err != nil { @@ -340,7 +299,7 @@ func (b *BinlogSyncer) writeBinglogDumpCommand(p Position) error { binary.LittleEndian.PutUint16(data[pos:], BINLOG_DUMP_NEVER_STOP) pos += 2 - binary.LittleEndian.PutUint32(data[pos:], b.serverID) + binary.LittleEndian.PutUint32(data[pos:], b.cfg.ServerID) pos += 4 copy(data[pos:], p.Name) @@ -362,7 +321,7 @@ func (b *BinlogSyncer) writeBinlogDumpMysqlGTIDCommand(gset GTIDSet) error { binary.LittleEndian.PutUint16(data[pos:], 0) pos += 2 - binary.LittleEndian.PutUint32(data[pos:], b.serverID) + binary.LittleEndian.PutUint32(data[pos:], b.cfg.ServerID) pos += 4 binary.LittleEndian.PutUint32(data[pos:], uint32(len(p.Name))) @@ -389,12 +348,6 @@ func (b *BinlogSyncer) writeBinlogDumpMariadbGTIDCommand(gset GTIDSet) error { startPos := gset.String() - // Tell the server that we understand GTIDs by setting our slave capability - // to MARIA_SLAVE_CAPABILITY_GTID = 4 (MariaDB >= 10.0.1). - if _, err := b.c.Execute("SET @mariadb_slave_capability=4"); err != nil { - return errors.Errorf("failed to set @mariadb_slave_capability=4: %v", err) - } - // Set the slave_connect_state variable before issuing COM_BINLOG_DUMP to // provide the start position in GTID form. query := fmt.Sprintf("SET @slave_connect_state='%s'", startPos) @@ -414,19 +367,28 @@ func (b *BinlogSyncer) writeBinlogDumpMariadbGTIDCommand(gset GTIDSet) error { return b.writeBinglogDumpCommand(Position{"", 0}) } +// localHostname returns the hostname that register slave would register as. +func (b *BinlogSyncer) localHostname() string { + if len(b.cfg.Localhost) == 0 { + h, _ := os.Hostname() + return h + } + return b.cfg.Localhost +} + func (b *BinlogSyncer) writeRegisterSlaveCommand() error { b.c.ResetSequence() - hostname := b.LocalHostname() + hostname := b.localHostname() // This should be the name of slave host not the host we are connecting to. - data := make([]byte, 4+1+4+1+len(hostname)+1+len(b.user)+1+len(b.password)+2+4+4) + data := make([]byte, 4+1+4+1+len(hostname)+1+len(b.cfg.User)+1+len(b.cfg.Password)+2+4+4) pos := 4 data[pos] = COM_REGISTER_SLAVE pos++ - binary.LittleEndian.PutUint32(data[pos:], b.serverID) + binary.LittleEndian.PutUint32(data[pos:], b.cfg.ServerID) pos += 4 // This should be the name of slave hostname not the host we are connecting to. @@ -435,24 +397,25 @@ func (b *BinlogSyncer) writeRegisterSlaveCommand() error { n := copy(data[pos:], hostname) pos += n - data[pos] = uint8(len(b.user)) + data[pos] = uint8(len(b.cfg.User)) pos++ - n = copy(data[pos:], b.user) + n = copy(data[pos:], b.cfg.User) pos += n - data[pos] = uint8(len(b.password)) + data[pos] = uint8(len(b.cfg.Password)) pos++ - n = copy(data[pos:], b.password) + n = copy(data[pos:], b.cfg.Password) pos += n - binary.LittleEndian.PutUint16(data[pos:], b.port) + binary.LittleEndian.PutUint16(data[pos:], b.cfg.Port) pos += 2 //replication rank, not used binary.LittleEndian.PutUint32(data[pos:], 0) pos += 4 - binary.LittleEndian.PutUint32(data[pos:], b.masterID) + // master ID, 0 is OK + binary.LittleEndian.PutUint32(data[pos:], 0) return b.c.WritePacket(data) } @@ -482,6 +445,37 @@ func (b *BinlogSyncer) replySemiSyncACK(p Position) error { return errors.Trace(err) } +func (b *BinlogSyncer) retrySync() error { + b.m.Lock() + defer b.m.Unlock() + + log.Infof("begin to re-sync from %s", b.nextPos) + + b.parser.Reset() + if err := b.prepareSyncPos(b.nextPos); err != nil { + return errors.Trace(err) + } + + return nil +} + +func (b *BinlogSyncer) prepareSyncPos(pos Position) error { + // always start from position 4 + if pos.Pos < 4 { + pos.Pos = 4 + } + + if err := b.prepare(); err != nil { + return errors.Trace(err) + } + + if err := b.writeBinglogDumpCommand(pos); err != nil { + return errors.Trace(err) + } + + return nil +} + func (b *BinlogSyncer) onStream(s *BinlogStreamer) { defer func() { if e := recover(); e != nil { @@ -493,8 +487,34 @@ func (b *BinlogSyncer) onStream(s *BinlogStreamer) { for { data, err := b.c.ReadPacket() if err != nil { - s.closeWithError(err) - return + log.Error(err) + + // we meet connection error, should re-connect again with + // last nextPos we got. + if len(b.nextPos.Name) == 0 { + // we can't get the correct position, close. + s.closeWithError(err) + return + } + + // TODO: add a max retry count. + for { + select { + case <-b.ctx.Done(): + s.close() + return + case <-time.After(time.Second): + if err = b.retrySync(); err != nil { + log.Errorf("retry sync err: %v, wait 1s and retry again", err) + continue + } + } + + break + } + + // we connect the server and begin to re-sync again. + continue } switch data[0] { @@ -507,9 +527,16 @@ func (b *BinlogSyncer) onStream(s *BinlogStreamer) { err = b.c.HandleErrorPacket(data) s.closeWithError(err) return + case EOF_HEADER: + // Refer http://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html + // In the MySQL client/server protocol, EOF and OK packets serve the same purpose. + // Some users told me that they received EOF packet here, but I don't know why. + // So we only log a message and retry ReadPacket. + log.Info("receive EOF packet, retry ReadPacket") + continue default: - s.closeWithError(fmt.Errorf("invalid stream header %c", data[0])) - return + log.Errorf("invalid stream header %c", data[0]) + continue } } } @@ -519,7 +546,7 @@ func (b *BinlogSyncer) parseEvent(s *BinlogStreamer, data []byte) error { data = data[1:] needACK := false - if b.semiSyncEnabled && (data[0] == SemiSyncIndicator) { + if b.cfg.SemiSyncEnabled && (data[0] == SemiSyncIndicator) { needACK = (data[1] == 0x01) //skip semi sync header data = data[2:] @@ -530,17 +557,21 @@ func (b *BinlogSyncer) parseEvent(s *BinlogStreamer, data []byte) error { return errors.Trace(err) } - b.nextPos.Pos = e.Header.LogPos + if e.Header.LogPos > 0 { + // Some events like FormatDescriptionEvent return 0, ignore. + b.nextPos.Pos = e.Header.LogPos + } if re, ok := e.Event.(*RotateEvent); ok { b.nextPos.Name = string(re.NextLogName) b.nextPos.Pos = uint32(re.Position) + log.Infof("rotate to %s", b.nextPos) } needStop := false select { case s.ch <- e: - case <-b.stopCh: + case <-b.ctx.Done(): needStop = true } diff --git a/vendor/github.com/siddontang/go-mysql/replication/bitmap.go b/vendor/github.com/siddontang/go-mysql/replication/bitmap.go deleted file mode 100644 index 5de9938..0000000 --- a/vendor/github.com/siddontang/go-mysql/replication/bitmap.go +++ /dev/null @@ -1,38 +0,0 @@ -package replication - -var bitCountInByte = [256]uint8{ - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, -} - -func bitCount(bitmap []byte) int { - var n uint32 = 0 - - for _, b := range bitmap { - n += uint32(bitCountInByte[b]) - } - - return int(n) -} - -func bitGet(bitmap []byte, i int) byte { - return bitmap[i/8] & (1 << (uint(i) & 7)) -} - -func isNullSet(nullBitmap []byte, i int) bool { - return nullBitmap[i/8]&(1<<(uint(i)%8)) > 0 -} diff --git a/vendor/github.com/siddontang/go-mysql/replication/json_binary.go b/vendor/github.com/siddontang/go-mysql/replication/json_binary.go new file mode 100644 index 0000000..83b69d0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/replication/json_binary.go @@ -0,0 +1,479 @@ +package replication + +import ( + "encoding/json" + "fmt" + "math" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go/hack" +) + +const ( + JSONB_SMALL_OBJECT byte = iota // small JSON object + JSONB_LARGE_OBJECT // large JSON object + JSONB_SMALL_ARRAY // small JSON array + JSONB_LARGE_ARRAY // large JSON array + JSONB_LITERAL // literal (true/false/null) + JSONB_INT16 // int16 + JSONB_UINT16 // uint16 + JSONB_INT32 // int32 + JSONB_UINT32 // uint32 + JSONB_INT64 // int64 + JSONB_UINT64 // uint64 + JSONB_DOUBLE // double + JSONB_STRING // string + JSONB_OPAQUE byte = 0x0f // custom data (any MySQL data type) +) + +const ( + JSONB_NULL_LITERAL byte = 0x00 + JSONB_TRUE_LITERAL byte = 0x01 + JSONB_FALSE_LITERAL byte = 0x02 +) + +const ( + jsonbSmallOffsetSize = 2 + jsonbLargeOffsetSize = 4 + + jsonbKeyEntrySizeSmall = 2 + jsonbSmallOffsetSize + jsonbKeyEntrySizeLarge = 2 + jsonbLargeOffsetSize + + jsonbValueEntrySizeSmall = 1 + jsonbSmallOffsetSize + jsonbValueEntrySizeLarge = 1 + jsonbLargeOffsetSize +) + +func jsonbGetOffsetSize(isSmall bool) int { + if isSmall { + return jsonbSmallOffsetSize + } + + return jsonbLargeOffsetSize +} + +func jsonbGetKeyEntrySize(isSmall bool) int { + if isSmall { + return jsonbKeyEntrySizeSmall + } + + return jsonbKeyEntrySizeLarge +} + +func jsonbGetValueEntrySize(isSmall bool) int { + if isSmall { + return jsonbValueEntrySizeSmall + } + + return jsonbValueEntrySizeLarge +} + +// decodeJsonBinary decodes the JSON binary encoding data and returns +// the common JSON encoding data. +func decodeJsonBinary(data []byte) ([]byte, error) { + d := new(jsonBinaryDecoder) + + if d.isDataShort(data, 1) { + return nil, d.err + } + + v := d.decodeValue(data[0], data[1:]) + if d.err != nil { + return nil, d.err + } + + return json.Marshal(v) +} + +type jsonBinaryDecoder struct { + err error +} + +func (d *jsonBinaryDecoder) decodeValue(tp byte, data []byte) interface{} { + if d.err != nil { + return nil + } + + switch tp { + case JSONB_SMALL_OBJECT: + return d.decodeObjectOrArray(data, true, true) + case JSONB_LARGE_OBJECT: + return d.decodeObjectOrArray(data, false, true) + case JSONB_SMALL_ARRAY: + return d.decodeObjectOrArray(data, true, false) + case JSONB_LARGE_ARRAY: + return d.decodeObjectOrArray(data, false, false) + case JSONB_LITERAL: + return d.decodeLiteral(data) + case JSONB_INT16: + return d.decodeInt16(data) + case JSONB_UINT16: + return d.decodeUint16(data) + case JSONB_INT32: + return d.decodeInt32(data) + case JSONB_UINT32: + return d.decodeUint32(data) + case JSONB_INT64: + return d.decodeInt64(data) + case JSONB_UINT64: + return d.decodeUint64(data) + case JSONB_DOUBLE: + return d.decodeDouble(data) + case JSONB_STRING: + return d.decodeString(data) + case JSONB_OPAQUE: + return d.decodeOpaque(data) + default: + d.err = errors.Errorf("invalid json type %d", tp) + } + + return nil +} + +func (d *jsonBinaryDecoder) decodeObjectOrArray(data []byte, isSmall bool, isObject bool) interface{} { + offsetSize := jsonbGetOffsetSize(isSmall) + if d.isDataShort(data, 2*offsetSize) { + return nil + } + + count := d.decodeCount(data, isSmall) + size := d.decodeCount(data[offsetSize:], isSmall) + + if d.isDataShort(data, int(size)) { + return nil + } + + keyEntrySize := jsonbGetKeyEntrySize(isSmall) + valueEntrySize := jsonbGetValueEntrySize(isSmall) + + headerSize := 2*offsetSize + count*valueEntrySize + + if isObject { + headerSize += count * keyEntrySize + } + + if headerSize > size { + d.err = errors.Errorf("header size %d > size %d", headerSize, size) + return nil + } + + var keys []string + if isObject { + keys = make([]string, count) + for i := 0; i < count; i++ { + // decode key + entryOffset := 2*offsetSize + keyEntrySize*i + keyOffset := d.decodeCount(data[entryOffset:], isSmall) + keyLength := int(d.decodeUint16(data[entryOffset+offsetSize:])) + + // Key must start after value entry + if keyOffset < headerSize { + d.err = errors.Errorf("invalid key offset %d, must > %d", keyOffset, headerSize) + return nil + } + + if d.isDataShort(data, keyOffset+keyLength) { + return nil + } + + keys[i] = hack.String(data[keyOffset : keyOffset+keyLength]) + } + } + + if d.err != nil { + return nil + } + + values := make([]interface{}, count) + for i := 0; i < count; i++ { + // decode value + entryOffset := 2*offsetSize + valueEntrySize*i + if isObject { + entryOffset += keyEntrySize * count + } + + tp := data[entryOffset] + + if isInlineValue(tp, isSmall) { + values[i] = d.decodeValue(tp, data[entryOffset+1:entryOffset+valueEntrySize]) + continue + } + + valueOffset := d.decodeCount(data[entryOffset+1:], isSmall) + + if d.isDataShort(data, valueOffset) { + return nil + } + + values[i] = d.decodeValue(tp, data[valueOffset:]) + } + + if d.err != nil { + return nil + } + + if !isObject { + return values + } + + m := make(map[string]interface{}, count) + for i := 0; i < count; i++ { + m[keys[i]] = values[i] + } + + return m +} + +func isInlineValue(tp byte, isSmall bool) bool { + switch tp { + case JSONB_INT16, JSONB_UINT16, JSONB_LITERAL: + return true + case JSONB_INT32, JSONB_UINT32: + return !isSmall + } + + return false +} + +func (d *jsonBinaryDecoder) decodeLiteral(data []byte) interface{} { + if d.isDataShort(data, 1) { + return nil + } + + tp := data[0] + + switch tp { + case JSONB_NULL_LITERAL: + return nil + case JSONB_TRUE_LITERAL: + return true + case JSONB_FALSE_LITERAL: + return false + } + + d.err = errors.Errorf("invalid literal %c", tp) + + return nil +} + +func (d *jsonBinaryDecoder) isDataShort(data []byte, expected int) bool { + if d.err != nil { + return true + } + + if len(data) < expected { + d.err = errors.Errorf("data len %d < expected %d", len(data), expected) + } + + return d.err != nil +} + +func (d *jsonBinaryDecoder) decodeInt16(data []byte) int16 { + if d.isDataShort(data, 2) { + return 0 + } + + v := ParseBinaryInt16(data[0:2]) + return v +} + +func (d *jsonBinaryDecoder) decodeUint16(data []byte) uint16 { + if d.isDataShort(data, 2) { + return 0 + } + + v := ParseBinaryUint16(data[0:2]) + return v +} + +func (d *jsonBinaryDecoder) decodeInt32(data []byte) int32 { + if d.isDataShort(data, 4) { + return 0 + } + + v := ParseBinaryInt32(data[0:4]) + return v +} + +func (d *jsonBinaryDecoder) decodeUint32(data []byte) uint32 { + if d.isDataShort(data, 4) { + return 0 + } + + v := ParseBinaryUint32(data[0:4]) + return v +} + +func (d *jsonBinaryDecoder) decodeInt64(data []byte) int64 { + if d.isDataShort(data, 8) { + return 0 + } + + v := ParseBinaryInt64(data[0:8]) + return v +} + +func (d *jsonBinaryDecoder) decodeUint64(data []byte) uint64 { + if d.isDataShort(data, 8) { + return 0 + } + + v := ParseBinaryUint64(data[0:8]) + return v +} + +func (d *jsonBinaryDecoder) decodeDouble(data []byte) float64 { + if d.isDataShort(data, 8) { + return 0 + } + + v := ParseBinaryFloat64(data[0:8]) + return v +} + +func (d *jsonBinaryDecoder) decodeString(data []byte) string { + if d.err != nil { + return "" + } + + l, n := d.decodeVariableLength(data) + + if d.isDataShort(data, int(l)+n) { + return "" + } + + data = data[n:] + + v := hack.String(data[0:l]) + return v +} + +func (d *jsonBinaryDecoder) decodeOpaque(data []byte) interface{} { + if d.isDataShort(data, 1) { + return nil + } + + tp := data[0] + data = data[1:] + + l, n := d.decodeVariableLength(data) + + if d.isDataShort(data, int(l)+n) { + return nil + } + + data = data[n : int(l)+n] + + switch tp { + case MYSQL_TYPE_NEWDECIMAL: + return d.decodeDecimal(data) + case MYSQL_TYPE_TIME: + return d.decodeTime(data) + case MYSQL_TYPE_DATE, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP: + return d.decodeDateTime(data) + default: + return hack.String(data) + } + + return nil +} + +func (d *jsonBinaryDecoder) decodeDecimal(data []byte) interface{} { + precision := int(data[0]) + scale := int(data[1]) + + v, _, err := decodeDecimal(data[2:], precision, scale) + d.err = err + + return v +} + +func (d *jsonBinaryDecoder) decodeTime(data []byte) interface{} { + v := d.decodeInt64(data) + + if v == 0 { + return "00:00:00" + } + + sign := "" + if v < 0 { + sign = "-" + v = -v + } + + intPart := v >> 24 + hour := (intPart >> 12) % (1 << 10) + min := (intPart >> 6) % (1 << 6) + sec := intPart % (1 << 6) + frac := v % (1 << 24) + + return fmt.Sprintf("%s%02d:%02d:%02d.%06d", sign, hour, min, sec, frac) +} + +func (d *jsonBinaryDecoder) decodeDateTime(data []byte) interface{} { + v := d.decodeInt64(data) + if v == 0 { + return "0000-00-00 00:00:00" + } + + // handle negative? + if v < 0 { + v = -v + } + + intPart := v >> 24 + ymd := intPart >> 17 + ym := ymd >> 5 + hms := intPart % (1 << 17) + + year := ym / 13 + month := ym % 13 + day := ymd % (1 << 5) + hour := (hms >> 12) + minute := (hms >> 6) % (1 << 6) + second := hms % (1 << 6) + frac := v % (1 << 24) + + return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%06d", year, month, day, hour, minute, second, frac) + +} + +func (d *jsonBinaryDecoder) decodeCount(data []byte, isSmall bool) int { + if isSmall { + v := d.decodeUint16(data) + return int(v) + } + + return int(d.decodeUint32(data)) +} + +func (d *jsonBinaryDecoder) decodeVariableLength(data []byte) (int, int) { + // The max size for variable length is math.MaxUint32, so + // here we can use 5 bytes to save it. + maxCount := 5 + if len(data) < maxCount { + maxCount = len(data) + } + + pos := 0 + length := uint64(0) + for ; pos < maxCount; pos++ { + v := data[pos] + length = (length << 7) + uint64(v&0x7F) + + if v&0x80 == 0 { + if length > math.MaxUint32 { + d.err = errors.Errorf("variable length %d must <= %d", length, math.MaxUint32) + return 0, 0 + } + + pos += 1 + // TODO: should consider length overflow int here. + return int(length), pos + } + } + + d.err = errors.New("decode variable length failed") + + return 0, 0 +} diff --git a/vendor/github.com/siddontang/go-mysql/replication/parser.go b/vendor/github.com/siddontang/go-mysql/replication/parser.go index 4610ab9..cfc97e4 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/parser.go +++ b/vendor/github.com/siddontang/go-mysql/replication/parser.go @@ -26,6 +26,10 @@ func NewBinlogParser() *BinlogParser { return p } +func (p *BinlogParser) Reset() { + p.format = nil +} + type OnEventFunc func(*BinlogEvent) error func (p *BinlogParser) ParseFile(name string, offset int64, onEvent OnEventFunc) error { @@ -50,12 +54,11 @@ func (p *BinlogParser) ParseFile(name string, offset int64, onEvent OnEventFunc) return errors.Errorf("seek %s to %d error %v", name, offset, err) } - return p.ParseReader(f, onEvent) + return p.parseReader(f, onEvent) } -func (p *BinlogParser) ParseReader(r io.Reader, onEvent OnEventFunc) error { - p.tables = make(map[uint64]*TableMapEvent) - p.format = nil +func (p *BinlogParser) parseReader(r io.Reader, onEvent OnEventFunc) error { + p.Reset() var err error var n int64 @@ -193,12 +196,11 @@ func (p *BinlogParser) parseEvent(h *EventHeader, data []byte) (Event, error) { p.tables[te.TableID] = te } - //If MySQL restart, it may use the same table id for different tables. - //We must clear the table map before parsing new events. - //We have no better way to known whether the event is before or after restart, - //So we have to clear the table map on every rotate event. - if _, ok := e.(*RotateEvent); ok { - p.tables = make(map[uint64]*TableMapEvent) + if re, ok := e.(*RowsEvent); ok { + if (re.Flags & RowsEventStmtEndFlag) > 0 { + // Refer https://github.com/alibaba/canal/blob/38cc81b7dab29b51371096fb6763ca3a8432ffee/dbsync/src/main/java/com/taobao/tddl/dbsync/binlog/event/RowsLogEvent.java#L176 + p.tables = make(map[uint64]*TableMapEvent) + } } return e, nil diff --git a/vendor/github.com/siddontang/go-mysql/replication/parser_test.go b/vendor/github.com/siddontang/go-mysql/replication/parser_test.go index ba1c96e..52f5f1f 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/parser_test.go +++ b/vendor/github.com/siddontang/go-mysql/replication/parser_test.go @@ -1,7 +1,7 @@ package replication import ( - . "gopkg.in/check.v1" + . "github.com/pingcap/check" ) func (t *testSyncerSuite) TestIndexOutOfRange(c *C) { diff --git a/vendor/github.com/siddontang/go-mysql/replication/replication_test.go b/vendor/github.com/siddontang/go-mysql/replication/replication_test.go index 4061089..8af37be 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/replication_test.go +++ b/vendor/github.com/siddontang/go-mysql/replication/replication_test.go @@ -8,15 +8,18 @@ import ( "testing" "time" + "golang.org/x/net/context" + + . "github.com/pingcap/check" + uuid "github.com/satori/go.uuid" "github.com/siddontang/go-mysql/client" "github.com/siddontang/go-mysql/mysql" - . "gopkg.in/check.v1" ) // Use docker mysql to test, mysql is 3306, mariadb is 3316 var testHost = flag.String("host", "127.0.0.1", "MySQL master host") -var testOutputLogs = flag.Bool("out", true, "output binlog event") +var testOutputLogs = flag.Bool("out", false, "output binlog event") func TestBinLogSyncer(t *testing.T) { TestingT(t) @@ -69,15 +72,19 @@ func (t *testSyncerSuite) testSync(c *C, s *BinlogStreamer) { return } + eventCount := 0 for { - e, err := s.GetEventTimeout(2 * time.Second) - if err != nil { - if err != ErrGetEventTimeout { - c.Fatal(err) - } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + e, err := s.GetEvent(ctx) + cancel() + + if err == context.DeadlineExceeded { + eventCount += 1 return } + c.Assert(err, IsNil) + if *testOutputLogs { e.Dump(os.Stdout) os.Stdout.Sync() @@ -91,26 +98,26 @@ func (t *testSyncerSuite) testSync(c *C, s *BinlogStreamer) { str := `DROP TABLE IF EXISTS test_replication` t.testExecute(c, str) - str = `CREATE TABLE IF NOT EXISTS test_replication ( - id BIGINT(64) UNSIGNED NOT NULL AUTO_INCREMENT, - str VARCHAR(256), - f FLOAT, - d DOUBLE, - de DECIMAL(10,2), - i INT, - bi BIGINT, - e enum ("e1", "e2"), - b BIT(8), - y YEAR, - da DATE, - ts TIMESTAMP, - dt DATETIME, - tm TIME, - t TEXT, - bb BLOB, - se SET('a', 'b', 'c'), - PRIMARY KEY (id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8` + str = `CREATE TABLE test_replication ( + id BIGINT(64) UNSIGNED NOT NULL AUTO_INCREMENT, + str VARCHAR(256), + f FLOAT, + d DOUBLE, + de DECIMAL(10,2), + i INT, + bi BIGINT, + e enum ("e1", "e2"), + b BIT(8), + y YEAR, + da DATE, + ts TIMESTAMP, + dt DATETIME, + tm TIME, + t TEXT, + bb BLOB, + se SET('a', 'b', 'c'), + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8` t.testExecute(c, str) @@ -132,6 +139,97 @@ func (t *testSyncerSuite) testSync(c *C, s *BinlogStreamer) { t.testExecute(c, fmt.Sprintf(`DELETE FROM test_replication WHERE id = %d`, id)) } + // check whether we can create the table including the json field + str = `DROP TABLE IF EXISTS test_json` + t.testExecute(c, str) + + str = `CREATE TABLE test_json ( + id BIGINT(64) UNSIGNED NOT NULL AUTO_INCREMENT, + c1 JSON, + c2 DECIMAL(10, 0), + PRIMARY KEY (id) + ) ENGINE=InnoDB` + + if _, err := t.c.Execute(str); err == nil { + t.testExecute(c, `INSERT INTO test_json (c2) VALUES (1)`) + t.testExecute(c, `INSERT INTO test_json (c1, c2) VALUES ('{"key1": "value1", "key2": "value2"}', 1)`) + } + + t.testExecute(c, "DROP TABLE IF EXISTS test_json_v2") + + str = `CREATE TABLE test_json_v2 ( + id INT, + c JSON, + PRIMARY KEY (id) + ) ENGINE=InnoDB` + + if _, err := t.c.Execute(str); err == nil { + tbls := []string{ + // Refer: https://github.com/shyiko/mysql-binlog-connector-java/blob/c8e81c879710dc19941d952f9031b0a98f8b7c02/src/test/java/com/github/shyiko/mysql/binlog/event/deserialization/json/JsonBinaryValueIntegrationTest.java#L84 + // License: https://github.com/shyiko/mysql-binlog-connector-java#license + `INSERT INTO test_json_v2 VALUES (0, NULL)`, + `INSERT INTO test_json_v2 VALUES (1, '{\"a\": 2}')`, + `INSERT INTO test_json_v2 VALUES (2, '[1,2]')`, + `INSERT INTO test_json_v2 VALUES (3, '{\"a\":\"b\", \"c\":\"d\",\"ab\":\"abc\", \"bc\": [\"x\", \"y\"]}')`, + `INSERT INTO test_json_v2 VALUES (4, '[\"here\", [\"I\", \"am\"], \"!!!\"]')`, + `INSERT INTO test_json_v2 VALUES (5, '\"scalar string\"')`, + `INSERT INTO test_json_v2 VALUES (6, 'true')`, + `INSERT INTO test_json_v2 VALUES (7, 'false')`, + `INSERT INTO test_json_v2 VALUES (8, 'null')`, + `INSERT INTO test_json_v2 VALUES (9, '-1')`, + `INSERT INTO test_json_v2 VALUES (10, CAST(CAST(1 AS UNSIGNED) AS JSON))`, + `INSERT INTO test_json_v2 VALUES (11, '32767')`, + `INSERT INTO test_json_v2 VALUES (12, '32768')`, + `INSERT INTO test_json_v2 VALUES (13, '-32768')`, + `INSERT INTO test_json_v2 VALUES (14, '-32769')`, + `INSERT INTO test_json_v2 VALUES (15, '2147483647')`, + `INSERT INTO test_json_v2 VALUES (16, '2147483648')`, + `INSERT INTO test_json_v2 VALUES (17, '-2147483648')`, + `INSERT INTO test_json_v2 VALUES (18, '-2147483649')`, + `INSERT INTO test_json_v2 VALUES (19, '18446744073709551615')`, + `INSERT INTO test_json_v2 VALUES (20, '18446744073709551616')`, + `INSERT INTO test_json_v2 VALUES (21, '3.14')`, + `INSERT INTO test_json_v2 VALUES (22, '{}')`, + `INSERT INTO test_json_v2 VALUES (23, '[]')`, + `INSERT INTO test_json_v2 VALUES (24, CAST(CAST('2015-01-15 23:24:25' AS DATETIME) AS JSON))`, + `INSERT INTO test_json_v2 VALUES (25, CAST(CAST('23:24:25' AS TIME) AS JSON))`, + `INSERT INTO test_json_v2 VALUES (125, CAST(CAST('23:24:25.12' AS TIME(3)) AS JSON))`, + `INSERT INTO test_json_v2 VALUES (225, CAST(CAST('23:24:25.0237' AS TIME(3)) AS JSON))`, + `INSERT INTO test_json_v2 VALUES (26, CAST(CAST('2015-01-15' AS DATE) AS JSON))`, + `INSERT INTO test_json_v2 VALUES (27, CAST(TIMESTAMP'2015-01-15 23:24:25' AS JSON))`, + `INSERT INTO test_json_v2 VALUES (127, CAST(TIMESTAMP'2015-01-15 23:24:25.12' AS JSON))`, + `INSERT INTO test_json_v2 VALUES (227, CAST(TIMESTAMP'2015-01-15 23:24:25.0237' AS JSON))`, + `INSERT INTO test_json_v2 VALUES (327, CAST(UNIX_TIMESTAMP('2015-01-15 23:24:25') AS JSON))`, + `INSERT INTO test_json_v2 VALUES (28, CAST(ST_GeomFromText('POINT(1 1)') AS JSON))`, + `INSERT INTO test_json_v2 VALUES (29, CAST('[]' AS CHAR CHARACTER SET 'ascii'))`, + // TODO: 30 and 31 are BIT type from JSON_TYPE, may support later. + `INSERT INTO test_json_v2 VALUES (30, CAST(x'cafe' AS JSON))`, + `INSERT INTO test_json_v2 VALUES (31, CAST(x'cafebabe' AS JSON))`, + `INSERT INTO test_json_v2 VALUES (100, CONCAT('{\"', REPEAT('a', 64 * 1024 - 1), '\":123}'))`, + } + + for _, query := range tbls { + t.testExecute(c, query) + } + + // If MySQL supports JSON, it must supports GEOMETRY. + t.testExecute(c, "DROP TABLE IF EXISTS test_geo") + + str = `CREATE TABLE test_geo (g GEOMETRY)` + _, err = t.c.Execute(str) + c.Assert(err, IsNil) + + tbls = []string{ + `INSERT INTO test_geo VALUES (POINT(1, 1))`, + `INSERT INTO test_geo VALUES (LINESTRING(POINT(0,0), POINT(1,1), POINT(2,2)))`, + // TODO: add more geometry tests + } + + for _, query := range tbls { + t.testExecute(c, query) + } + } + t.wg.Wait() } @@ -164,10 +262,16 @@ func (t *testSyncerSuite) setupTest(c *C, flavor string) { t.b.Close() } - t.b = NewBinlogSyncer(100, flavor) + cfg := BinlogSyncerConfig{ + ServerID: 100, + Flavor: flavor, + Host: *testHost, + Port: port, + User: "root", + Password: "", + } - err = t.b.RegisterSlave(*testHost, port, "root", "") - c.Assert(err, IsNil) + t.b = NewBinlogSyncer(&cfg) } func (t *testSyncerSuite) testPositionSync(c *C) { @@ -180,6 +284,11 @@ func (t *testSyncerSuite) testPositionSync(c *C) { s, err := t.b.StartSync(mysql.Position{binFile, uint32(binPos)}) c.Assert(err, IsNil) + // Test re-sync. + time.Sleep(100 * time.Millisecond) + t.b.c.SetReadDeadline(time.Now().Add(time.Millisecond)) + time.Sleep(100 * time.Millisecond) + t.testSync(c, s) } @@ -198,9 +307,15 @@ func (t *testSyncerSuite) TestMysqlGTIDSync(c *C) { c.Skip("GTID mode is not ON") } - masterUuid, err := t.b.GetMasterUUID() + r, err = t.c.Execute("SHOW GLOBAL VARIABLES LIKE 'SERVER_UUID'") c.Assert(err, IsNil) + var masterUuid uuid.UUID + if s, _ := r.GetString(0, 1); len(s) > 0 && s != "NONE" { + masterUuid, err = uuid.FromString(s) + c.Assert(err, IsNil) + } + set, _ := mysql.ParseMysqlGTIDSet(fmt.Sprintf("%s:%d-%d", masterUuid.String(), 1, 2)) s, err := t.b.StartSyncGTID(set) @@ -234,10 +349,7 @@ func (t *testSyncerSuite) TestMariadbGTIDSync(c *C) { func (t *testSyncerSuite) TestMysqlSemiPositionSync(c *C) { t.setupTest(c, mysql.MySQLFlavor) - err := t.b.EnableSemiSync() - if err != nil { - c.Skip(fmt.Sprintf("mysql doest not support semi synchronous replication %v", err)) - } + t.b.cfg.SemiSyncEnabled = true t.testPositionSync(c) } @@ -249,6 +361,7 @@ func (t *testSyncerSuite) TestMysqlBinlogCodec(c *C) { var wg sync.WaitGroup wg.Add(1) + defer wg.Wait() go func() { defer wg.Done() @@ -263,7 +376,7 @@ func (t *testSyncerSuite) TestMysqlBinlogCodec(c *C) { os.RemoveAll("./var") err := t.b.StartBackup("./var", mysql.Position{"", uint32(0)}, 2*time.Second) - c.Check(err, Equals, ErrGetEventTimeout) + c.Assert(err, IsNil) p := NewBinlogParser() diff --git a/vendor/github.com/siddontang/go-mysql/replication/row_event.go b/vendor/github.com/siddontang/go-mysql/replication/row_event.go index ff6c913..b345ae1 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/row_event.go +++ b/vendor/github.com/siddontang/go-mysql/replication/row_event.go @@ -10,6 +10,7 @@ import ( "time" "github.com/juju/errors" + "github.com/ngaut/log" . "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go/hack" ) @@ -155,7 +156,8 @@ func (e *TableMapEvent) decodeMeta(data []byte) error { case MYSQL_TYPE_BLOB, MYSQL_TYPE_DOUBLE, MYSQL_TYPE_FLOAT, - MYSQL_TYPE_GEOMETRY: + MYSQL_TYPE_GEOMETRY, + MYSQL_TYPE_JSON: e.ColumnMeta[i] = uint16(data[pos]) pos++ case MYSQL_TYPE_TIME2, @@ -180,6 +182,7 @@ func (e *TableMapEvent) decodeMeta(data []byte) error { func (e *TableMapEvent) Dump(w io.Writer) { fmt.Fprintf(w, "TableID: %d\n", e.TableID) + fmt.Fprintf(w, "TableID size: %d\n", e.tableIDSize) fmt.Fprintf(w, "Flags: %d\n", e.Flags) fmt.Fprintf(w, "Schema: %s\n", e.Schema) fmt.Fprintf(w, "Table: %s\n", e.Table) @@ -189,6 +192,9 @@ func (e *TableMapEvent) Dump(w io.Writer) { fmt.Fprintln(w) } +// RowsEventStmtEndFlag is set in the end of the statement. +const RowsEventStmtEndFlag = 0x01 + type RowsEvent struct { //0, 1, 2 Version int @@ -257,13 +263,13 @@ func (e *RowsEvent) Decode(data []byte) error { var err error // ... repeat rows until event-end - // Why using pos < len(data) - 1 instead of origin pos < len(data) to check? - // See mysql - // https://github.com/mysql/mysql-server/blob/5.7/sql/log_event.cc#L9006 - // https://github.com/mysql/mysql-server/blob/5.7/sql/log_event.cc#L2492 - // A user panics here but he can't give me more information, and using len(data) - 1 - // fixes this panic, so I will try to construct a test case for this later. - for pos < len(data)-1 { + defer func() { + if r := recover(); r != nil { + log.Fatalf("parse rows event panic %v, data %q, parsed rows %#v, table map %#v\n%s", r, data, e, e.Table, Pstack()) + } + }() + + for pos < len(data) { if n, err = e.decodeRows(data[pos:], e.Table, e.ColumnBitmap1); err != nil { return errors.Trace(err) } @@ -280,12 +286,23 @@ func (e *RowsEvent) Decode(data []byte) error { return nil } +func isBitSet(bitmap []byte, i int) bool { + return bitmap[i>>3]&(1<<(uint(i)&7)) > 0 +} + func (e *RowsEvent) decodeRows(data []byte, table *TableMapEvent, bitmap []byte) (int, error) { row := make([]interface{}, e.ColumnCount) pos := 0 - count := (bitCount(bitmap) + 7) / 8 + // refer: https://github.com/alibaba/canal/blob/c3e38e50e269adafdd38a48c63a1740cde304c67/dbsync/src/main/java/com/taobao/tddl/dbsync/binlog/event/RowsLogBuffer.java#L63 + count := 0 + for i := 0; i < int(e.ColumnCount); i++ { + if isBitSet(bitmap, i) { + count++ + } + } + count = (count + 7) / 8 nullBitmap := data[pos : pos+count] pos += count @@ -295,26 +312,24 @@ func (e *RowsEvent) decodeRows(data []byte, table *TableMapEvent, bitmap []byte) var n int var err error for i := 0; i < int(e.ColumnCount); i++ { - if bitGet(bitmap, i) == 0 { + if !isBitSet(bitmap, i) { continue } isNull := (uint32(nullBitmap[nullbitIndex/8]) >> uint32(nullbitIndex%8)) & 0x01 + nullbitIndex++ if isNull > 0 { row[i] = nil - nullbitIndex++ continue } row[i], n, err = e.decodeValue(data[pos:], table.ColumnType[i], table.ColumnMeta[i]) if err != nil { - return 0, nil + return 0, err } pos += n - - nullbitIndex++ } e.Rows = append(e.Rows, row) @@ -436,37 +451,31 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16) (v interface{ err = fmt.Errorf("Unknown ENUM packlen=%d", l) } case MYSQL_TYPE_SET: - nbits := meta & 0xFF - n = int(nbits+7) / 8 + n = int(meta & 0xFF) + nbits := n * 8 - v, err = decodeBit(data, int(nbits), n) + v, err = decodeBit(data, nbits, n) case MYSQL_TYPE_BLOB: - switch meta { - case 1: - length = int(data[0]) - v = data[1 : 1+length] - n = length + 1 - case 2: - length = int(binary.LittleEndian.Uint16(data)) - v = data[2 : 2+length] - n = length + 2 - case 3: - length = int(FixedLengthInt(data[0:3])) - v = data[3 : 3+length] - n = length + 3 - case 4: - length = int(binary.LittleEndian.Uint32(data)) - v = data[4 : 4+length] - n = length + 4 - default: - err = fmt.Errorf("invalid blob packlen = %d", meta) - } + v, n, err = decodeBlob(data, meta) case MYSQL_TYPE_VARCHAR, MYSQL_TYPE_VAR_STRING: length = int(meta) v, n = decodeString(data, length) case MYSQL_TYPE_STRING: v, n = decodeString(data, length) + case MYSQL_TYPE_JSON: + // Refer https://github.com/shyiko/mysql-binlog-connector-java/blob/8f9132ee773317e00313204beeae8ddcaa43c1b4/src/main/java/com/github/shyiko/mysql/binlog/event/deserialization/AbstractRowsEventDataDeserializer.java#L344 + length = int(binary.LittleEndian.Uint32(data[0:])) + n = length + int(meta) + v, err = decodeJsonBinary(data[meta:n]) + case MYSQL_TYPE_GEOMETRY: + // MySQL saves Geometry as Blob in binlog + // Seem that the binary format is SRID (4 bytes) + WKB, outer can use + // MySQL GeoFromWKB or others to create the geometry data. + // Refer https://dev.mysql.com/doc/refman/5.7/en/gis-wkb-functions.html + // I also find some go libs to handle WKB if possible + // see https://github.com/twpayne/go-geom or https://github.com/paulmach/go.geo + v, n, err = decodeBlob(data, meta) default: err = fmt.Errorf("unsupport type %d in binlog and don't know how to handle", tp) } @@ -591,7 +600,7 @@ func decodeBit(data []byte, nbits int, length int) (value int64, err error) { return } -func decodeTimestamp2(data []byte, dec uint16) (interface{}, int, error) { +func decodeTimestamp2(data []byte, dec uint16) (string, int, error) { //get timestamp binary length n := int(4 + (dec+1)/2) sec := int64(binary.BigEndian.Uint32(data[0:4])) @@ -610,12 +619,12 @@ func decodeTimestamp2(data []byte, dec uint16) (interface{}, int, error) { } t := time.Unix(sec, usec*1000) - return t, n, nil + return t.Format(TimeFormat), n, nil } const DATETIMEF_INT_OFS int64 = 0x8000000000 -func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { +func decodeDatetime2(data []byte, dec uint16) (string, int, error) { //get datetime binary length n := int(5 + (dec+1)/2) @@ -641,7 +650,6 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { tmp = -tmp } - //ingore second part, no precision now var secPart int64 = tmp % (1 << 24) ymdhms := tmp >> 24 @@ -658,9 +666,9 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { hour := int((hms >> 12)) if secPart != 0 { - return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%d", year, month, day, hour, minute, second, secPart), n, nil // commented by Shlomi Noach. Yes I know about `git blame` + return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%06d", year, month, day, hour, minute, second, secPart), n, nil } - return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second), n, nil // commented by Shlomi Noach. Yes I know about `git blame` + return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second), n, nil } const TIMEF_OFS int64 = 0x800000000000 @@ -733,19 +741,46 @@ func decodeTime2(data []byte, dec uint16) (string, int, error) { sign = "-" } - //ingore second part, no precision now - //var secPart int64 = tmp % (1 << 24) - hms = tmp >> 24 hour := (hms >> 12) % (1 << 10) /* 10 bits starting at 12th */ minute := (hms >> 6) % (1 << 6) /* 6 bits starting at 6th */ second := hms % (1 << 6) /* 6 bits starting at 0th */ - // secondPart := tmp % (1 << 24) + secPart := tmp % (1 << 24) + + if secPart != 0 { + return fmt.Sprintf("%s%02d:%02d:%02d.%06d", sign, hour, minute, second, secPart), n, nil + } return fmt.Sprintf("%s%02d:%02d:%02d", sign, hour, minute, second), n, nil } +func decodeBlob(data []byte, meta uint16) (v []byte, n int, err error) { + var length int + switch meta { + case 1: + length = int(data[0]) + v = data[1 : 1+length] + n = length + 1 + case 2: + length = int(binary.LittleEndian.Uint16(data)) + v = data[2 : 2+length] + n = length + 2 + case 3: + length = int(FixedLengthInt(data[0:3])) + v = data[3 : 3+length] + n = length + 3 + case 4: + length = int(binary.LittleEndian.Uint32(data)) + v = data[4 : 4+length] + n = length + 4 + default: + err = fmt.Errorf("invalid blob packlen = %d", meta) + } + + return +} + func (e *RowsEvent) Dump(w io.Writer) { fmt.Fprintf(w, "TableID: %d\n", e.TableID) fmt.Fprintf(w, "Flags: %d\n", e.Flags) diff --git a/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go b/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go index 539bab6..2d63bb5 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go +++ b/vendor/github.com/siddontang/go-mysql/replication/row_event_test.go @@ -3,7 +3,7 @@ package replication import ( "fmt" - . "gopkg.in/check.v1" + . "github.com/pingcap/check" ) type testDecodeSuite struct{} @@ -326,3 +326,106 @@ func (_ *testDecodeSuite) TestDecodeDecimal(c *C) { c.Assert(value, DecodeDecimalsEquals, pos, err, tc.Expected, tc.ExpectedPos, tc.ExpectedErr, i) } } + +func (_ *testDecodeSuite) TestLastNull(c *C) { + // Table format: + // desc funnytable; + // +-------+------------+------+-----+---------+-------+ + // | Field | Type | Null | Key | Default | Extra | + // +-------+------------+------+-----+---------+-------+ + // | value | tinyint(4) | YES | | NULL | | + // +-------+------------+------+-----+---------+-------+ + + // insert into funnytable values (1), (2), (null); + // insert into funnytable values (1), (null), (2); + // all must get 3 rows + + tableMapEventData := []byte("\xd3\x01\x00\x00\x00\x00\x01\x00\x04test\x00\nfunnytable\x00\x01\x01\x00\x01") + + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + err := tableMapEvent.Decode(tableMapEventData) + c.Assert(err, IsNil) + + rows := new(RowsEvent) + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + tbls := [][]byte{ + []byte("\xd3\x01\x00\x00\x00\x00\x01\x00\x02\x00\x01\xff\xfe\x01\xff\xfe\x02"), + []byte("\xd3\x01\x00\x00\x00\x00\x01\x00\x02\x00\x01\xff\xfe\x01\xfe\x02\xff"), + } + + for _, tbl := range tbls { + rows.Rows = nil + err = rows.Decode(tbl) + c.Assert(err, IsNil) + c.Assert(rows.Rows, HasLen, 3) + } +} + +func (_ *testDecodeSuite) TestParseRowPanic(c *C) { + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + tableMapEvent.TableID = 1810 + tableMapEvent.ColumnType = []byte{3, 15, 15, 15, 9, 15, 15, 252, 3, 3, 3, 15, 3, 3, 3, 15, 3, 15, 1, 15, 3, 1, 252, 15, 15, 15} + tableMapEvent.ColumnMeta = []uint16{0, 108, 60, 765, 0, 765, 765, 4, 0, 0, 0, 765, 0, 0, 0, 3, 0, 3, 0, 765, 0, 0, 2, 108, 108, 108} + + rows := new(RowsEvent) + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + data := []byte{18, 7, 0, 0, 0, 0, 1, 0, 2, 0, 26, 1, 1, 16, 252, 248, 142, 63, 0, 0, 13, 0, 0, 0, 13, 0, 0, 0} + + err := rows.Decode(data) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][0], Equals, int32(16270)) +} + +func (_ *testDecodeSuite) TestParseJson(c *C) { + // Table format: + // mysql> desc t10; + // +-------+---------------+------+-----+---------+-------+ + // | Field | Type | Null | Key | Default | Extra | + // +-------+---------------+------+-----+---------+-------+ + // | c1 | json | YES | | NULL | | + // | c2 | decimal(10,0) | YES | | NULL | | + // +-------+---------------+------+-----+---------+-------+ + + // CREATE TABLE `t10` ( + // `c1` json DEFAULT NULL, + // `c2` decimal(10,0) + // ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + + // INSERT INTO `t10` (`c2`) VALUES (1); + // INSERT INTO `t10` (`c1`, `c2`) VALUES ('{"key1": "value1", "key2": "value2"}', 1); + + tableMapEventData := []byte("m\x00\x00\x00\x00\x00\x01\x00\x04test\x00\x03t10\x00\x02\xf5\xf6\x03\x04\n\x00\x03") + + tableMapEvent := new(TableMapEvent) + tableMapEvent.tableIDSize = 6 + err := tableMapEvent.Decode(tableMapEventData) + c.Assert(err, IsNil) + + rows := new(RowsEvent) + rows.tableIDSize = 6 + rows.tables = make(map[uint64]*TableMapEvent) + rows.tables[tableMapEvent.TableID] = tableMapEvent + rows.Version = 2 + + tbls := [][]byte{ + []byte("m\x00\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfd\x80\x00\x00\x00\x01"), + []byte("m\x00\x00\x00\x00\x00\x01\x00\x02\x00\x02\xff\xfc)\x00\x00\x00\x00\x02\x00(\x00\x12\x00\x04\x00\x16\x00\x04\x00\f\x1a\x00\f!\x00key1key2\x06value1\x06value2\x80\x00\x00\x00\x01"), + } + + for _, tbl := range tbls { + rows.Rows = nil + err = rows.Decode(tbl) + c.Assert(err, IsNil) + c.Assert(rows.Rows[0][1], Equals, float64(1)) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/schema/schema.go b/vendor/github.com/siddontang/go-mysql/schema/schema.go new file mode 100644 index 0000000..86d2128 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/schema/schema.go @@ -0,0 +1,215 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schema + +import ( + "fmt" + "strings" + + "github.com/juju/errors" + "github.com/siddontang/go-mysql/mysql" +) + +const ( + TYPE_NUMBER = iota + 1 // tinyint, smallint, mediumint, int, bigint, year + TYPE_FLOAT // float, double + TYPE_ENUM // enum + TYPE_SET // set + TYPE_STRING // other + TYPE_DATETIME // datetime + TYPE_TIMESTAMP // timestamp + TYPE_DATE // date + TYPE_TIME // time + TYPE_BIT // bit + TYPE_JSON // json +) + +type TableColumn struct { + Name string + Type int + IsAuto bool + EnumValues []string + SetValues []string +} + +type Index struct { + Name string + Columns []string + Cardinality []uint64 +} + +type Table struct { + Schema string + Name string + + Columns []TableColumn + Indexes []*Index + PKColumns []int +} + +func (ta *Table) String() string { + return fmt.Sprintf("%s.%s", ta.Schema, ta.Name) +} + +func (ta *Table) AddColumn(name string, columnType string, extra string) { + index := len(ta.Columns) + ta.Columns = append(ta.Columns, TableColumn{Name: name}) + + if strings.Contains(columnType, "int") || strings.HasPrefix(columnType, "year") { + ta.Columns[index].Type = TYPE_NUMBER + } else if strings.HasPrefix(columnType, "float") || + strings.HasPrefix(columnType, "double") || + strings.HasPrefix(columnType, "decimal") { + ta.Columns[index].Type = TYPE_FLOAT + } else if strings.HasPrefix(columnType, "enum") { + ta.Columns[index].Type = TYPE_ENUM + ta.Columns[index].EnumValues = strings.Split(strings.Replace( + strings.TrimSuffix( + strings.TrimPrefix( + columnType, "enum("), + ")"), + "'", "", -1), + ",") + } else if strings.HasPrefix(columnType, "set") { + ta.Columns[index].Type = TYPE_SET + ta.Columns[index].SetValues = strings.Split(strings.Replace( + strings.TrimSuffix( + strings.TrimPrefix( + columnType, "set("), + ")"), + "'", "", -1), + ",") + } else if strings.HasPrefix(columnType, "datetime") { + ta.Columns[index].Type = TYPE_DATETIME + } else if strings.HasPrefix(columnType, "timestamp") { + ta.Columns[index].Type = TYPE_TIMESTAMP + } else if strings.HasPrefix(columnType, "time") { + ta.Columns[index].Type = TYPE_TIME + } else if "date" == columnType { + ta.Columns[index].Type = TYPE_DATE + } else if strings.HasPrefix(columnType, "bit") { + ta.Columns[index].Type = TYPE_BIT + } else if strings.HasPrefix(columnType, "json") { + ta.Columns[index].Type = TYPE_JSON + } else { + ta.Columns[index].Type = TYPE_STRING + } + + if extra == "auto_increment" { + ta.Columns[index].IsAuto = true + } +} + +func (ta *Table) FindColumn(name string) int { + for i, col := range ta.Columns { + if col.Name == name { + return i + } + } + return -1 +} + +func (ta *Table) GetPKColumn(index int) *TableColumn { + return &ta.Columns[ta.PKColumns[index]] +} + +func (ta *Table) AddIndex(name string) (index *Index) { + index = NewIndex(name) + ta.Indexes = append(ta.Indexes, index) + return index +} + +func NewIndex(name string) *Index { + return &Index{name, make([]string, 0, 8), make([]uint64, 0, 8)} +} + +func (idx *Index) AddColumn(name string, cardinality uint64) { + idx.Columns = append(idx.Columns, name) + if cardinality == 0 { + cardinality = uint64(len(idx.Cardinality) + 1) + } + idx.Cardinality = append(idx.Cardinality, cardinality) +} + +func (idx *Index) FindColumn(name string) int { + for i, colName := range idx.Columns { + if name == colName { + return i + } + } + return -1 +} + +func NewTable(conn mysql.Executer, schema string, name string) (*Table, error) { + ta := &Table{ + Schema: schema, + Name: name, + Columns: make([]TableColumn, 0, 16), + Indexes: make([]*Index, 0, 8), + } + + if err := ta.fetchColumns(conn); err != nil { + return nil, err + } + + if err := ta.fetchIndexes(conn); err != nil { + return nil, err + } + + return ta, nil +} + +func (ta *Table) fetchColumns(conn mysql.Executer) error { + r, err := conn.Execute(fmt.Sprintf("describe `%s`.`%s`", ta.Schema, ta.Name)) + if err != nil { + return errors.Trace(err) + } + + for i := 0; i < r.RowNumber(); i++ { + name, _ := r.GetString(i, 0) + colType, _ := r.GetString(i, 1) + extra, _ := r.GetString(i, 5) + + ta.AddColumn(name, colType, extra) + } + + return nil +} + +func (ta *Table) fetchIndexes(conn mysql.Executer) error { + r, err := conn.Execute(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name)) + if err != nil { + return errors.Trace(err) + } + var currentIndex *Index + currentName := "" + + for i := 0; i < r.RowNumber(); i++ { + indexName, _ := r.GetString(i, 2) + if currentName != indexName { + currentIndex = ta.AddIndex(indexName) + currentName = indexName + } + cardinality, _ := r.GetUint(i, 6) + colName, _ := r.GetString(i, 4) + currentIndex.AddColumn(colName, cardinality) + } + + if len(ta.Indexes) == 0 { + return nil + } + + pkIndex := ta.Indexes[0] + if pkIndex.Name != "PRIMARY" { + return nil + } + + ta.PKColumns = make([]int, len(pkIndex.Columns)) + for i, pkCol := range pkIndex.Columns { + ta.PKColumns[i] = ta.FindColumn(pkCol) + } + + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/schema/schema_test.go b/vendor/github.com/siddontang/go-mysql/schema/schema_test.go new file mode 100644 index 0000000..327c622 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/schema/schema_test.go @@ -0,0 +1,84 @@ +package schema + +import ( + "flag" + "fmt" + "testing" + + . "github.com/pingcap/check" + "github.com/siddontang/go-mysql/client" +) + +// use docker mysql for test +var host = flag.String("host", "127.0.0.1", "MySQL host") + +func Test(t *testing.T) { + TestingT(t) +} + +type schemaTestSuite struct { + conn *client.Conn +} + +var _ = Suite(&schemaTestSuite{}) + +func (s *schemaTestSuite) SetUpSuite(c *C) { + var err error + s.conn, err = client.Connect(fmt.Sprintf("%s:%d", *host, 3306), "root", "", "test") + c.Assert(err, IsNil) +} + +func (s *schemaTestSuite) TearDownSuite(c *C) { + if s.conn != nil { + s.conn.Close() + } +} + +func (s *schemaTestSuite) TestSchema(c *C) { + _, err := s.conn.Execute(`DROP TABLE IF EXISTS schema_test`) + c.Assert(err, IsNil) + + str := ` + CREATE TABLE IF NOT EXISTS schema_test ( + id INT, + id1 INT, + id2 INT, + name VARCHAR(256), + e ENUM("a", "b", "c"), + se SET('a', 'b', 'c'), + f FLOAT, + d DECIMAL(2, 1), + PRIMARY KEY(id2, id), + UNIQUE (id1), + INDEX name_idx (name) + ) ENGINE = INNODB; + ` + + _, err = s.conn.Execute(str) + c.Assert(err, IsNil) + + ta, err := NewTable(s.conn, "test", "schema_test") + c.Assert(err, IsNil) + + c.Assert(ta.Columns, HasLen, 8) + c.Assert(ta.Indexes, HasLen, 3) + c.Assert(ta.PKColumns, DeepEquals, []int{2, 0}) + c.Assert(ta.Indexes[0].Columns, HasLen, 2) + c.Assert(ta.Indexes[0].Name, Equals, "PRIMARY") + c.Assert(ta.Indexes[2].Name, Equals, "name_idx") + c.Assert(ta.Columns[4].EnumValues, DeepEquals, []string{"a", "b", "c"}) + c.Assert(ta.Columns[5].SetValues, DeepEquals, []string{"a", "b", "c"}) + c.Assert(ta.Columns[7].Type, Equals, TYPE_FLOAT) +} + +func (s *schemaTestSuite) TestQuoteSchema(c *C) { + str := "CREATE TABLE IF NOT EXISTS `a-b_test` (`a.b` INT) ENGINE = INNODB" + + _, err := s.conn.Execute(str) + c.Assert(err, IsNil) + + ta, err := NewTable(s.conn, "test", "a-b_test") + c.Assert(err, IsNil) + + c.Assert(ta.Columns[0].Name, Equals, "a.b") +} diff --git a/vendor/github.com/siddontang/go-mysql/server/auth.go b/vendor/github.com/siddontang/go-mysql/server/auth.go new file mode 100644 index 0000000..b66ea4e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/auth.go @@ -0,0 +1,119 @@ +package server + +import ( + "bytes" + "encoding/binary" + + . "github.com/siddontang/go-mysql/mysql" +) + +func (c *Conn) writeInitialHandshake() error { + capability := CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | + CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | + CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION + + data := make([]byte, 4, 128) + + //min version 10 + data = append(data, 10) + + //server version[00] + data = append(data, ServerVersion...) + data = append(data, 0) + + //connection id + data = append(data, byte(c.connectionID), byte(c.connectionID>>8), byte(c.connectionID>>16), byte(c.connectionID>>24)) + + //auth-plugin-data-part-1 + data = append(data, c.salt[0:8]...) + + //filter [00] + data = append(data, 0) + + //capability flag lower 2 bytes, using default capability here + data = append(data, byte(capability), byte(capability>>8)) + + //charset, utf-8 default + data = append(data, uint8(DEFAULT_COLLATION_ID)) + + //status + data = append(data, byte(c.status), byte(c.status>>8)) + + //below 13 byte may not be used + //capability flag upper 2 bytes, using default capability here + data = append(data, byte(capability>>16), byte(capability>>24)) + + //filter [0x15], for wireshark dump, value is 0x15 + data = append(data, 0x15) + + //reserved 10 [00] + data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + + //auth-plugin-data-part-2 + data = append(data, c.salt[8:]...) + + //filter [00] + data = append(data, 0) + + return c.WritePacket(data) +} + +func (c *Conn) readHandshakeResponse(password string) error { + data, err := c.ReadPacket() + + if err != nil { + return err + } + + pos := 0 + + //capability + c.capability = binary.LittleEndian.Uint32(data[:4]) + pos += 4 + + //skip max packet size + pos += 4 + + //charset, skip, if you want to use another charset, use set names + //c.collation = CollationId(data[pos]) + pos++ + + //skip reserved 23[00] + pos += 23 + + //user name + user := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) + pos += len(user) + 1 + + if c.user != user { + return NewDefaultError(ER_NO_SUCH_USER, user, c.RemoteAddr().String()) + } + + //auth length and auth + authLen := int(data[pos]) + pos++ + auth := data[pos : pos+authLen] + + checkAuth := CalcPassword(c.salt, []byte(password)) + + if !bytes.Equal(auth, checkAuth) { + return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.RemoteAddr().String(), c.user, "Yes") + } + + pos += authLen + + if c.capability|CLIENT_CONNECT_WITH_DB > 0 { + if len(data[pos:]) == 0 { + return nil + } + + db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) + pos += len(db) + 1 + + if err = c.h.UseDB(db); err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/server/command.go b/vendor/github.com/siddontang/go-mysql/server/command.go new file mode 100644 index 0000000..fb7dcdf --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/command.go @@ -0,0 +1,151 @@ +package server + +import ( + "bytes" + "fmt" + + . "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go/hack" +) + +type Handler interface { + //handle COM_INIT_DB command, you can check whether the dbName is valid, or other. + UseDB(dbName string) error + //handle COM_QUERY comamnd, like SELECT, INSERT, UPDATE, etc... + //If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the repsonse, otherwise, we will send Result + HandleQuery(query string) (*Result, error) + //handle COM_FILED_LIST command + HandleFieldList(table string, fieldWildcard string) ([]*Field, error) + //handle COM_STMT_PREPARE, params is the param number for this statement, columns is the column number + //context will be used later for statement execute + HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) + //handle COM_STMT_EXECUTE, context is the previous one set in prepare + //query is the statement prepare query, and args is the params for this statement + HandleStmtExecute(context interface{}, query string, args []interface{}) (*Result, error) + //handle COM_STMT_CLOSE, context is the previous one set in prepare + //this handler has no response + HandleStmtClose(context interface{}) error +} + +func (c *Conn) HandleCommand() error { + if c.Conn == nil { + return fmt.Errorf("connection closed") + } + + data, err := c.ReadPacket() + if err != nil { + c.Close() + c.Conn = nil + return err + } + + v := c.dispatch(data) + + err = c.writeValue(v) + + if c.Conn != nil { + c.ResetSequence() + } + + if err != nil { + c.Close() + c.Conn = nil + } + return err +} + +func (c *Conn) dispatch(data []byte) interface{} { + cmd := data[0] + data = data[1:] + + switch cmd { + case COM_QUIT: + c.Close() + c.Conn = nil + return noResponse{} + case COM_QUERY: + if r, err := c.h.HandleQuery(hack.String(data)); err != nil { + return err + } else { + return r + } + case COM_PING: + return nil + case COM_INIT_DB: + if err := c.h.UseDB(hack.String(data)); err != nil { + return err + } else { + return nil + } + case COM_FIELD_LIST: + index := bytes.IndexByte(data, 0x00) + table := hack.String(data[0:index]) + wildcard := hack.String(data[index+1:]) + + if fs, err := c.h.HandleFieldList(table, wildcard); err != nil { + return err + } else { + return fs + } + case COM_STMT_PREPARE: + c.stmtID++ + st := new(Stmt) + st.ID = c.stmtID + st.Query = hack.String(data) + var err error + if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil { + return err + } else { + st.ResetParams() + c.stmts[c.stmtID] = st + return st + } + case COM_STMT_EXECUTE: + if r, err := c.handleStmtExecute(data); err != nil { + return err + } else { + return r + } + case COM_STMT_CLOSE: + c.handleStmtClose(data) + return noResponse{} + case COM_STMT_SEND_LONG_DATA: + c.handleStmtSendLongData(data) + return noResponse{} + case COM_STMT_RESET: + if r, err := c.handleStmtReset(data); err != nil { + return err + } else { + return r + } + default: + msg := fmt.Sprintf("command %d is not supported now", cmd) + return NewError(ER_UNKNOWN_ERROR, msg) + } + + return fmt.Errorf("command %d is not handled correctly", cmd) +} + +type EmptyHandler struct { +} + +func (h EmptyHandler) UseDB(dbName string) error { + return nil +} +func (h EmptyHandler) HandleQuery(query string) (*Result, error) { + return nil, fmt.Errorf("not supported now") +} + +func (h EmptyHandler) HandleFieldList(table string, fieldWildcard string) ([]*Field, error) { + return nil, fmt.Errorf("not supported now") +} +func (h EmptyHandler) HandleStmtPrepare(query string) (int, int, interface{}, error) { + return 0, 0, nil, fmt.Errorf("not supported now") +} +func (h EmptyHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*Result, error) { + return nil, fmt.Errorf("not supported now") +} + +func (h EmptyHandler) HandleStmtClose(context interface{}) error { + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/server/conn.go b/vendor/github.com/siddontang/go-mysql/server/conn.go new file mode 100644 index 0000000..d6ea846 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/conn.go @@ -0,0 +1,113 @@ +package server + +import ( + "net" + "sync/atomic" + + . "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/packet" + "github.com/siddontang/go/sync2" +) + +/* + Conn acts like a MySQL server connection, you can use MySQL client to communicate with it. +*/ +type Conn struct { + *packet.Conn + + capability uint32 + + connectionID uint32 + + status uint16 + + user string + + salt []byte + + h Handler + + stmts map[uint32]*Stmt + stmtID uint32 + + closed sync2.AtomicBool +} + +var baseConnID uint32 = 10000 + +func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, error) { + c := new(Conn) + + c.h = h + + c.user = user + c.Conn = packet.NewConn(conn) + + c.connectionID = atomic.AddUint32(&baseConnID, 1) + + c.stmts = make(map[uint32]*Stmt) + + c.salt, _ = RandomBuf(20) + + c.closed.Set(false) + + if err := c.handshake(password); err != nil { + c.Close() + return nil, err + } + + return c, nil +} + +func (c *Conn) handshake(password string) error { + if err := c.writeInitialHandshake(); err != nil { + return err + } + + if err := c.readHandshakeResponse(password); err != nil { + c.writeError(err) + + return err + } + + if err := c.writeOK(nil); err != nil { + return err + } + + c.ResetSequence() + + return nil +} + +func (c *Conn) Close() { + c.closed.Set(true) + c.Conn.Close() +} + +func (c *Conn) Closed() bool { + return c.closed.Get() +} + +func (c *Conn) GetUser() string { + return c.user +} + +func (c *Conn) ConnectionID() uint32 { + return c.connectionID +} + +func (c *Conn) IsAutoCommit() bool { + return c.status&SERVER_STATUS_AUTOCOMMIT > 0 +} + +func (c *Conn) IsInTransaction() bool { + return c.status&SERVER_STATUS_IN_TRANS > 0 +} + +func (c *Conn) SetInTransaction() { + c.status |= SERVER_STATUS_IN_TRANS +} + +func (c *Conn) ClearInTransaction() { + c.status &= ^SERVER_STATUS_IN_TRANS +} diff --git a/vendor/github.com/siddontang/go-mysql/server/resp.go b/vendor/github.com/siddontang/go-mysql/server/resp.go new file mode 100644 index 0000000..1123032 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/resp.go @@ -0,0 +1,142 @@ +package server + +import ( + "fmt" + + . "github.com/siddontang/go-mysql/mysql" +) + +func (c *Conn) writeOK(r *Result) error { + if r == nil { + r = &Result{} + } + + r.Status |= c.status + + data := make([]byte, 4, 32) + + data = append(data, OK_HEADER) + + data = append(data, PutLengthEncodedInt(r.AffectedRows)...) + data = append(data, PutLengthEncodedInt(r.InsertId)...) + + if c.capability&CLIENT_PROTOCOL_41 > 0 { + data = append(data, byte(r.Status), byte(r.Status>>8)) + data = append(data, 0, 0) + } + + return c.WritePacket(data) +} + +func (c *Conn) writeError(e error) error { + var m *MyError + var ok bool + if m, ok = e.(*MyError); !ok { + m = NewError(ER_UNKNOWN_ERROR, e.Error()) + } + + data := make([]byte, 4, 16+len(m.Message)) + + data = append(data, ERR_HEADER) + data = append(data, byte(m.Code), byte(m.Code>>8)) + + if c.capability&CLIENT_PROTOCOL_41 > 0 { + data = append(data, '#') + data = append(data, m.State...) + } + + data = append(data, m.Message...) + + return c.WritePacket(data) +} + +func (c *Conn) writeEOF() error { + data := make([]byte, 4, 9) + + data = append(data, EOF_HEADER) + if c.capability&CLIENT_PROTOCOL_41 > 0 { + data = append(data, 0, 0) + data = append(data, byte(c.status), byte(c.status>>8)) + } + + return c.WritePacket(data) +} + +func (c *Conn) writeResultset(r *Resultset) error { + columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) + + data := make([]byte, 4, 1024) + + data = append(data, columnLen...) + if err := c.WritePacket(data); err != nil { + return err + } + + for _, v := range r.Fields { + data = data[0:4] + data = append(data, v.Dump()...) + if err := c.WritePacket(data); err != nil { + return err + } + } + + if err := c.writeEOF(); err != nil { + return err + } + + for _, v := range r.RowDatas { + data = data[0:4] + data = append(data, v...) + if err := c.WritePacket(data); err != nil { + return err + } + } + + if err := c.writeEOF(); err != nil { + return err + } + + return nil +} + +func (c *Conn) writeFieldList(fs []*Field) error { + data := make([]byte, 4, 1024) + + for _, v := range fs { + data = data[0:4] + data = append(data, v.Dump()...) + if err := c.WritePacket(data); err != nil { + return err + } + } + + if err := c.writeEOF(); err != nil { + return err + } + return nil +} + +type noResponse struct{} + +func (c *Conn) writeValue(value interface{}) error { + switch v := value.(type) { + case noResponse: + return nil + case error: + return c.writeError(v) + case nil: + return c.writeOK(nil) + case *Result: + if v != nil && v.Resultset != nil { + return c.writeResultset(v.Resultset) + } else { + return c.writeOK(v) + } + case []*Field: + return c.writeFieldList(v) + case *Stmt: + return c.writePrepare(v) + default: + return fmt.Errorf("invalid response type %T", value) + } +} diff --git a/vendor/github.com/siddontang/go-mysql/server/server_test.go b/vendor/github.com/siddontang/go-mysql/server/server_test.go new file mode 100644 index 0000000..54bcf05 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/server_test.go @@ -0,0 +1,230 @@ +package server + +import ( + "database/sql" + "flag" + "fmt" + "net" + "strings" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/juju/errors" + . "github.com/pingcap/check" + mysql "github.com/siddontang/go-mysql/mysql" +) + +var testAddr = flag.String("addr", "127.0.0.1:4000", "MySQL proxy server address") +var testUser = flag.String("user", "root", "MySQL user") +var testPassword = flag.String("pass", "", "MySQL password") +var testDB = flag.String("db", "test", "MySQL test database") + +func Test(t *testing.T) { + TestingT(t) +} + +type serverTestSuite struct { + db *sql.DB + + l net.Listener +} + +var _ = Suite(&serverTestSuite{}) + +type testHandler struct { + s *serverTestSuite +} + +func (h *testHandler) UseDB(dbName string) error { + return nil +} + +func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { + ss := strings.Split(query, " ") + switch strings.ToLower(ss[0]) { + case "select": + var r *mysql.Resultset + var err error + //for handle go mysql driver select @@max_allowed_packet + if strings.Contains(strings.ToLower(query), "max_allowed_packet") { + r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ + []interface{}{mysql.MaxPayloadLen}, + }, binary) + } else { + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ + []interface{}{1, "hello world"}, + }, binary) + } + + if err != nil { + return nil, errors.Trace(err) + } else { + return &mysql.Result{0, 0, 0, r}, nil + } + case "insert": + return &mysql.Result{0, 1, 0, nil}, nil + case "delete": + return &mysql.Result{0, 0, 1, nil}, nil + case "update": + return &mysql.Result{0, 0, 1, nil}, nil + case "replace": + return &mysql.Result{0, 0, 1, nil}, nil + default: + return nil, fmt.Errorf("invalid query %s", query) + } + + return nil, nil +} + +func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) { + return h.handleQuery(query, false) +} + +func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + return nil, nil +} +func (h *testHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) { + ss := strings.Split(sql, " ") + switch strings.ToLower(ss[0]) { + case "select": + params = 1 + columns = 2 + case "insert": + params = 2 + columns = 0 + case "replace": + params = 2 + columns = 0 + case "update": + params = 1 + columns = 0 + case "delete": + params = 1 + columns = 0 + default: + err = fmt.Errorf("invalid prepare %s", sql) + } + return params, columns, nil, err +} + +func (h *testHandler) HandleStmtClose(context interface{}) error { + return nil +} + +func (h *testHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) { + return h.handleQuery(query, true) +} + +func (s *serverTestSuite) SetUpSuite(c *C) { + var err error + + s.l, err = net.Listen("tcp", *testAddr) + c.Assert(err, IsNil) + + go s.onAccept(c) + + time.Sleep(500 * time.Millisecond) + + s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", *testUser, *testPassword, *testAddr, *testDB)) + c.Assert(err, IsNil) + + s.db.SetMaxIdleConns(4) +} + +func (s *serverTestSuite) TearDownSuite(c *C) { + if s.db != nil { + s.db.Close() + } + + if s.l != nil { + s.l.Close() + } +} + +func (s *serverTestSuite) onAccept(c *C) { + for { + conn, err := s.l.Accept() + if err != nil { + return + } + + go s.onConn(conn, c) + } +} + +func (s *serverTestSuite) onConn(conn net.Conn, c *C) { + co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) + c.Assert(err, IsNil) + + for { + err = co.HandleCommand() + if err != nil { + return + } + } +} + +func (s *serverTestSuite) TestSelect(c *C) { + var a int64 + var b string + + err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=1").Scan(&a, &b) + c.Assert(err, IsNil) + c.Assert(a, Equals, int64(1)) + c.Assert(b, Equals, "hello world") +} + +func (s *serverTestSuite) TestExec(c *C) { + r, err := s.db.Exec("INSERT INTO tbl (a, b) values (1, \"hello world\")") + c.Assert(err, IsNil) + i, _ := r.LastInsertId() + c.Assert(i, Equals, int64(1)) + + r, err = s.db.Exec("REPLACE INTO tbl (a, b) values (1, \"hello world\")") + c.Assert(err, IsNil) + i, _ = r.RowsAffected() + c.Assert(i, Equals, int64(1)) + + r, err = s.db.Exec("UPDATE tbl SET b = \"abc\" where a = 1") + c.Assert(err, IsNil) + i, _ = r.RowsAffected() + c.Assert(i, Equals, int64(1)) + + r, err = s.db.Exec("DELETE FROM tbl where a = 1") + c.Assert(err, IsNil) + i, _ = r.RowsAffected() + c.Assert(i, Equals, int64(1)) +} + +func (s *serverTestSuite) TestStmtSelect(c *C) { + var a int64 + var b string + + err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=?", 1).Scan(&a, &b) + c.Assert(err, IsNil) + c.Assert(a, Equals, int64(1)) + c.Assert(b, Equals, "hello world") +} + +func (s *serverTestSuite) TestStmtExec(c *C) { + r, err := s.db.Exec("INSERT INTO tbl (a, b) values (?, ?)", 1, "hello world") + c.Assert(err, IsNil) + i, _ := r.LastInsertId() + c.Assert(i, Equals, int64(1)) + + r, err = s.db.Exec("REPLACE INTO tbl (a, b) values (?, ?)", 1, "hello world") + c.Assert(err, IsNil) + i, _ = r.RowsAffected() + c.Assert(i, Equals, int64(1)) + + r, err = s.db.Exec("UPDATE tbl SET b = \"abc\" where a = ?", 1) + c.Assert(err, IsNil) + i, _ = r.RowsAffected() + c.Assert(i, Equals, int64(1)) + + r, err = s.db.Exec("DELETE FROM tbl where a = ?", 1) + c.Assert(err, IsNil) + i, _ = r.RowsAffected() + c.Assert(i, Equals, int64(1)) +} diff --git a/vendor/github.com/siddontang/go-mysql/server/stmt.go b/vendor/github.com/siddontang/go-mysql/server/stmt.go new file mode 100644 index 0000000..7a325d7 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/stmt.go @@ -0,0 +1,363 @@ +package server + +import ( + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" +) + +var paramFieldData []byte +var columnFieldData []byte + +func init() { + var p = &Field{Name: []byte("?")} + var c = &Field{} + + paramFieldData = p.Dump() + columnFieldData = c.Dump() +} + +type Stmt struct { + ID uint32 + Query string + + Params int + Columns int + + Args []interface{} + + Context interface{} +} + +func (s *Stmt) Rest(params int, columns int, context interface{}) { + s.Params = params + s.Columns = columns + s.Context = context + s.ResetParams() +} + +func (s *Stmt) ResetParams() { + s.Args = make([]interface{}, s.Params) +} + +func (c *Conn) writePrepare(s *Stmt) error { + data := make([]byte, 4, 128) + + //status ok + data = append(data, 0) + //stmt id + data = append(data, Uint32ToBytes(s.ID)...) + //number columns + data = append(data, Uint16ToBytes(uint16(s.Columns))...) + //number params + data = append(data, Uint16ToBytes(uint16(s.Params))...) + //filter [00] + data = append(data, 0) + //warning count + data = append(data, 0, 0) + + if err := c.WritePacket(data); err != nil { + return err + } + + if s.Params > 0 { + for i := 0; i < s.Params; i++ { + data = data[0:4] + data = append(data, []byte(paramFieldData)...) + + if err := c.WritePacket(data); err != nil { + return errors.Trace(err) + } + } + + if err := c.writeEOF(); err != nil { + return err + } + } + + if s.Columns > 0 { + for i := 0; i < s.Columns; i++ { + data = data[0:4] + data = append(data, []byte(columnFieldData)...) + + if err := c.WritePacket(data); err != nil { + return errors.Trace(err) + } + } + + if err := c.writeEOF(); err != nil { + return err + } + + } + return nil +} + +func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { + if len(data) < 9 { + return nil, ErrMalformPacket + } + + pos := 0 + id := binary.LittleEndian.Uint32(data[0:4]) + pos += 4 + + s, ok := c.stmts[id] + if !ok { + return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, + strconv.FormatUint(uint64(id), 10), "stmt_execute") + } + + flag := data[pos] + pos++ + //now we only support CURSOR_TYPE_NO_CURSOR flag + if flag != 0 { + return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag)) + } + + //skip iteration-count, always 1 + pos += 4 + + var nullBitmaps []byte + var paramTypes []byte + var paramValues []byte + + paramNum := s.Params + + if paramNum > 0 { + nullBitmapLen := (s.Params + 7) >> 3 + if len(data) < (pos + nullBitmapLen + 1) { + return nil, ErrMalformPacket + } + nullBitmaps = data[pos : pos+nullBitmapLen] + pos += nullBitmapLen + + //new param bound flag + if data[pos] == 1 { + pos++ + if len(data) < (pos + (paramNum << 1)) { + return nil, ErrMalformPacket + } + + paramTypes = data[pos : pos+(paramNum<<1)] + pos += (paramNum << 1) + + paramValues = data[pos:] + } + + if err := c.bindStmtArgs(s, nullBitmaps, paramTypes, paramValues); err != nil { + return nil, errors.Trace(err) + } + } + + var r *Result + var err error + if r, err = c.h.HandleStmtExecute(s.Context, s.Query, s.Args); err != nil { + return nil, errors.Trace(err) + } + + s.ResetParams() + + return r, nil +} + +func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error { + args := s.Args + + pos := 0 + + var v []byte + var n int = 0 + var isNull bool + var err error + + for i := 0; i < s.Params; i++ { + if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { + args[i] = nil + continue + } + + tp := paramTypes[i<<1] + isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 + + switch tp { + case MYSQL_TYPE_NULL: + args[i] = nil + continue + + case MYSQL_TYPE_TINY: + if len(paramValues) < (pos + 1) { + return ErrMalformPacket + } + + if isUnsigned { + args[i] = uint8(paramValues[pos]) + } else { + args[i] = int8(paramValues[pos]) + } + + pos++ + continue + + case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: + if len(paramValues) < (pos + 2) { + return ErrMalformPacket + } + + if isUnsigned { + args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) + } else { + args[i] = int16((binary.LittleEndian.Uint16(paramValues[pos : pos+2]))) + } + pos += 2 + continue + + case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG: + if len(paramValues) < (pos + 4) { + return ErrMalformPacket + } + + if isUnsigned { + args[i] = uint32(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) + } else { + args[i] = int32(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) + } + pos += 4 + continue + + case MYSQL_TYPE_LONGLONG: + if len(paramValues) < (pos + 8) { + return ErrMalformPacket + } + + if isUnsigned { + args[i] = binary.LittleEndian.Uint64(paramValues[pos : pos+8]) + } else { + args[i] = int64(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) + } + pos += 8 + continue + + case MYSQL_TYPE_FLOAT: + if len(paramValues) < (pos + 4) { + return ErrMalformPacket + } + + args[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))) + pos += 4 + continue + + case MYSQL_TYPE_DOUBLE: + if len(paramValues) < (pos + 8) { + return ErrMalformPacket + } + + args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) + pos += 8 + continue + + case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR, + MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, + MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, + MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY, + MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE, + MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME: + if len(paramValues) < (pos + 1) { + return ErrMalformPacket + } + + v, isNull, n, err = LengthEnodedString(paramValues[pos:]) + pos += n + if err != nil { + return errors.Trace(err) + } + + if !isNull { + args[i] = v + continue + } else { + args[i] = nil + continue + } + default: + return errors.Errorf("Stmt Unknown FieldType %d", tp) + } + } + return nil +} + +// stmt send long data command has no repsonse +func (c *Conn) handleStmtSendLongData(data []byte) error { + if len(data) < 6 { + return nil + } + + id := binary.LittleEndian.Uint32(data[0:4]) + + s, ok := c.stmts[id] + if !ok { + return nil + } + + paramId := binary.LittleEndian.Uint16(data[4:6]) + if paramId >= uint16(s.Params) { + return nil + } + + if s.Args[paramId] == nil { + s.Args[paramId] = data[6:] + } else { + if b, ok := s.Args[paramId].([]byte); ok { + b = append(b, data[6:]...) + s.Args[paramId] = b + } else { + return nil + } + } + + return nil +} + +func (c *Conn) handleStmtReset(data []byte) (*Result, error) { + if len(data) < 4 { + return nil, ErrMalformPacket + } + + id := binary.LittleEndian.Uint32(data[0:4]) + + s, ok := c.stmts[id] + if !ok { + return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, + strconv.FormatUint(uint64(id), 10), "stmt_reset") + } + + s.ResetParams() + + return &Result{}, nil +} + +// stmt close command has no repsonse +func (c *Conn) handleStmtClose(data []byte) error { + if len(data) < 4 { + return nil + } + + id := binary.LittleEndian.Uint32(data[0:4]) + + stmt, ok := c.stmts[id] + if !ok { + return nil + } + + if err := c.h.HandleStmtClose(stmt.Context); err != nil { + return err + } + + delete(c.stmts, id) + + return nil +} diff --git a/vendor/github.com/siddontang/go-mysql/vitess_license b/vendor/github.com/siddontang/go-mysql/vitess_license new file mode 100644 index 0000000..989d02e --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/vitess_license @@ -0,0 +1,28 @@ +Copyright 2012, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. From 2d5f8398d6bf46ce93e511ca179aceb32cd6073e Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 12 Feb 2017 13:15:53 +0200 Subject: [PATCH 095/104] gomysql_reader.go adapted to api changes in go-mysql --- go/binlog/gomysql_reader.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index 7b57a99..445617a 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -16,6 +16,7 @@ import ( "github.com/outbrain/golib/log" gomysql "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/replication" + "golang.org/x/net/context" ) type GoMySQLReader struct { @@ -39,7 +40,16 @@ func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *G } serverId := uint32(binlogReader.MigrationContext.ReplicaServerId) - binlogReader.binlogSyncer = replication.NewBinlogSyncer(serverId, "mysql") + + binlogSyncerConfig := &replication.BinlogSyncerConfig{ + ServerID: serverId, + Flavor: "mysql", + Host: connectionConfig.Key.Hostname, + Port: uint16(connectionConfig.Key.Port), + User: connectionConfig.User, + Password: connectionConfig.Password, + } + binlogReader.binlogSyncer = replication.NewBinlogSyncer(binlogSyncerConfig) return binlogReader, err } @@ -49,10 +59,6 @@ func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordin if coordinates.IsEmpty() { return log.Errorf("Emptry coordinates at ConnectBinlogStreamer()") } - log.Infof("Registering replica at %+v:%+v", this.connectionConfig.Key.Hostname, uint16(this.connectionConfig.Key.Port)) - if err := this.binlogSyncer.RegisterSlave(this.connectionConfig.Key.Hostname, uint16(this.connectionConfig.Key.Port), this.connectionConfig.User, this.connectionConfig.Password); err != nil { - return err - } this.currentCoordinates = coordinates log.Infof("Connecting binlog streamer at %+v", this.currentCoordinates) @@ -126,7 +132,7 @@ func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesCha if canStopStreaming() { break } - ev, err := this.binlogStreamer.GetEvent() + ev, err := this.binlogStreamer.GetEvent(context.Background()) if err != nil { return err } From 350eab79bb5e7213a814492087af895779cb6e4d Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 12 Feb 2017 13:17:56 +0200 Subject: [PATCH 096/104] adding vendor/github.com/ngaut/ dependency --- vendor/github.com/ngaut/deadline/rw.go | 50 ++ vendor/github.com/ngaut/log/LICENSE | 165 ++++++ vendor/github.com/ngaut/log/README.md | 2 + vendor/github.com/ngaut/log/crash_unix.go | 18 + vendor/github.com/ngaut/log/crash_win.go | 37 ++ vendor/github.com/ngaut/log/log.go | 380 ++++++++++++++ vendor/github.com/ngaut/log/log_test.go | 197 +++++++ vendor/github.com/ngaut/pools/id_pool.go | 72 +++ vendor/github.com/ngaut/pools/id_pool_test.go | 118 +++++ vendor/github.com/ngaut/pools/numbered.go | 149 ++++++ .../github.com/ngaut/pools/numbered_test.go | 84 +++ .../github.com/ngaut/pools/resource_pool.go | 228 ++++++++ .../ngaut/pools/resource_pool_test.go | 487 ++++++++++++++++++ vendor/github.com/ngaut/pools/roundrobin.go | 214 ++++++++ .../github.com/ngaut/pools/roundrobin_test.go | 126 +++++ vendor/github.com/ngaut/pools/vitess_license | 28 + vendor/github.com/ngaut/sync2/atomic.go | 114 ++++ vendor/github.com/ngaut/sync2/atomic_test.go | 32 ++ vendor/github.com/ngaut/sync2/cond.go | 56 ++ vendor/github.com/ngaut/sync2/cond_test.go | 276 ++++++++++ vendor/github.com/ngaut/sync2/semaphore.go | 55 ++ .../github.com/ngaut/sync2/semaphore_test.go | 41 ++ .../github.com/ngaut/sync2/service_manager.go | 121 +++++ .../ngaut/sync2/service_manager_test.go | 176 +++++++ vendor/github.com/ngaut/sync2/vitess_license | 28 + vendor/github.com/ngaut/systimemon/LICENSE | 201 ++++++++ vendor/github.com/ngaut/systimemon/README.md | 2 + .../ngaut/systimemon/systime_mon.go | 24 + .../ngaut/systimemon/systime_mon_test.go | 28 + 29 files changed, 3509 insertions(+) create mode 100644 vendor/github.com/ngaut/deadline/rw.go create mode 100644 vendor/github.com/ngaut/log/LICENSE create mode 100644 vendor/github.com/ngaut/log/README.md create mode 100644 vendor/github.com/ngaut/log/crash_unix.go create mode 100644 vendor/github.com/ngaut/log/crash_win.go create mode 100644 vendor/github.com/ngaut/log/log.go create mode 100644 vendor/github.com/ngaut/log/log_test.go create mode 100644 vendor/github.com/ngaut/pools/id_pool.go create mode 100644 vendor/github.com/ngaut/pools/id_pool_test.go create mode 100644 vendor/github.com/ngaut/pools/numbered.go create mode 100644 vendor/github.com/ngaut/pools/numbered_test.go create mode 100644 vendor/github.com/ngaut/pools/resource_pool.go create mode 100644 vendor/github.com/ngaut/pools/resource_pool_test.go create mode 100644 vendor/github.com/ngaut/pools/roundrobin.go create mode 100644 vendor/github.com/ngaut/pools/roundrobin_test.go create mode 100644 vendor/github.com/ngaut/pools/vitess_license create mode 100644 vendor/github.com/ngaut/sync2/atomic.go create mode 100644 vendor/github.com/ngaut/sync2/atomic_test.go create mode 100644 vendor/github.com/ngaut/sync2/cond.go create mode 100644 vendor/github.com/ngaut/sync2/cond_test.go create mode 100644 vendor/github.com/ngaut/sync2/semaphore.go create mode 100644 vendor/github.com/ngaut/sync2/semaphore_test.go create mode 100644 vendor/github.com/ngaut/sync2/service_manager.go create mode 100644 vendor/github.com/ngaut/sync2/service_manager_test.go create mode 100644 vendor/github.com/ngaut/sync2/vitess_license create mode 100644 vendor/github.com/ngaut/systimemon/LICENSE create mode 100644 vendor/github.com/ngaut/systimemon/README.md create mode 100644 vendor/github.com/ngaut/systimemon/systime_mon.go create mode 100644 vendor/github.com/ngaut/systimemon/systime_mon_test.go diff --git a/vendor/github.com/ngaut/deadline/rw.go b/vendor/github.com/ngaut/deadline/rw.go new file mode 100644 index 0000000..19d4368 --- /dev/null +++ b/vendor/github.com/ngaut/deadline/rw.go @@ -0,0 +1,50 @@ +package deadline + +import ( + "io" + "time" +) + +type DeadlineReader interface { + io.Reader + SetReadDeadline(t time.Time) error +} + +type DeadlineWriter interface { + io.Writer + SetWriteDeadline(t time.Time) error +} + +type DeadlineReadWriter interface { + io.ReadWriter + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error +} + +type deadlineReader struct { + DeadlineReader + timeout time.Duration +} + +func (r *deadlineReader) Read(p []byte) (int, error) { + r.DeadlineReader.SetReadDeadline(time.Now().Add(r.timeout)) + return r.DeadlineReader.Read(p) +} + +func NewDeadlineReader(r DeadlineReader, timeout time.Duration) io.Reader { + return &deadlineReader{DeadlineReader: r, timeout: timeout} +} + +type deadlineWriter struct { + DeadlineWriter + timeout time.Duration +} + +func (r *deadlineWriter) Write(p []byte) (int, error) { + r.DeadlineWriter.SetWriteDeadline(time.Now().Add(r.timeout)) + return r.DeadlineWriter.Write(p) +} + +func NewDeadlineWriter(r DeadlineWriter, timeout time.Duration) io.Writer { + return &deadlineWriter{DeadlineWriter: r, timeout: timeout} +} diff --git a/vendor/github.com/ngaut/log/LICENSE b/vendor/github.com/ngaut/log/LICENSE new file mode 100644 index 0000000..6600f1c --- /dev/null +++ b/vendor/github.com/ngaut/log/LICENSE @@ -0,0 +1,165 @@ +GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/vendor/github.com/ngaut/log/README.md b/vendor/github.com/ngaut/log/README.md new file mode 100644 index 0000000..e0e857e --- /dev/null +++ b/vendor/github.com/ngaut/log/README.md @@ -0,0 +1,2 @@ +logging +======= diff --git a/vendor/github.com/ngaut/log/crash_unix.go b/vendor/github.com/ngaut/log/crash_unix.go new file mode 100644 index 0000000..37f407d --- /dev/null +++ b/vendor/github.com/ngaut/log/crash_unix.go @@ -0,0 +1,18 @@ +// +build freebsd openbsd netbsd dragonfly darwin linux + +package log + +import ( + "log" + "os" + "syscall" +) + +func CrashLog(file string) { + f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.Println(err.Error()) + } else { + syscall.Dup2(int(f.Fd()), 2) + } +} diff --git a/vendor/github.com/ngaut/log/crash_win.go b/vendor/github.com/ngaut/log/crash_win.go new file mode 100644 index 0000000..7d612ee --- /dev/null +++ b/vendor/github.com/ngaut/log/crash_win.go @@ -0,0 +1,37 @@ +// +build windows + +package log + +import ( + "log" + "os" + "syscall" +) + +var ( + kernel32 = syscall.MustLoadDLL("kernel32.dll") + procSetStdHandle = kernel32.MustFindProc("SetStdHandle") +) + +func setStdHandle(stdhandle int32, handle syscall.Handle) error { + r0, _, e1 := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdhandle), uintptr(handle), 0) + if r0 == 0 { + if e1 != 0 { + return error(e1) + } + return syscall.EINVAL + } + return nil +} + +func CrashLog(file string) { + f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.Println(err.Error()) + } else { + err = setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(f.Fd())) + if err != nil { + log.Println(err.Error()) + } + } +} diff --git a/vendor/github.com/ngaut/log/log.go b/vendor/github.com/ngaut/log/log.go new file mode 100644 index 0000000..896b393 --- /dev/null +++ b/vendor/github.com/ngaut/log/log.go @@ -0,0 +1,380 @@ +//high level log wrapper, so it can output different log based on level +package log + +import ( + "fmt" + "io" + "log" + "os" + "runtime" + "sync" + "time" +) + +const ( + Ldate = log.Ldate + Llongfile = log.Llongfile + Lmicroseconds = log.Lmicroseconds + Lshortfile = log.Lshortfile + LstdFlags = log.LstdFlags + Ltime = log.Ltime +) + +type ( + LogLevel int + LogType int +) + +const ( + LOG_FATAL = LogType(0x1) + LOG_ERROR = LogType(0x2) + LOG_WARNING = LogType(0x4) + LOG_INFO = LogType(0x8) + LOG_DEBUG = LogType(0x10) +) + +const ( + LOG_LEVEL_NONE = LogLevel(0x0) + LOG_LEVEL_FATAL = LOG_LEVEL_NONE | LogLevel(LOG_FATAL) + LOG_LEVEL_ERROR = LOG_LEVEL_FATAL | LogLevel(LOG_ERROR) + LOG_LEVEL_WARN = LOG_LEVEL_ERROR | LogLevel(LOG_WARNING) + LOG_LEVEL_INFO = LOG_LEVEL_WARN | LogLevel(LOG_INFO) + LOG_LEVEL_DEBUG = LOG_LEVEL_INFO | LogLevel(LOG_DEBUG) + LOG_LEVEL_ALL = LOG_LEVEL_DEBUG +) + +const FORMAT_TIME_DAY string = "20060102" +const FORMAT_TIME_HOUR string = "2006010215" + +var _log *logger = New() + +func init() { + SetFlags(Ldate | Ltime | Lshortfile) + SetHighlighting(runtime.GOOS != "windows") +} + +func Logger() *log.Logger { + return _log._log +} + +func SetLevel(level LogLevel) { + _log.SetLevel(level) +} +func GetLogLevel() LogLevel { + return _log.level +} + +func SetOutput(out io.Writer) { + _log.SetOutput(out) +} + +func SetOutputByName(path string) error { + return _log.SetOutputByName(path) +} + +func SetFlags(flags int) { + _log._log.SetFlags(flags) +} + +func Info(v ...interface{}) { + _log.Info(v...) +} + +func Infof(format string, v ...interface{}) { + _log.Infof(format, v...) +} + +func Debug(v ...interface{}) { + _log.Debug(v...) +} + +func Debugf(format string, v ...interface{}) { + _log.Debugf(format, v...) +} + +func Warn(v ...interface{}) { + _log.Warning(v...) +} + +func Warnf(format string, v ...interface{}) { + _log.Warningf(format, v...) +} + +func Warning(v ...interface{}) { + _log.Warning(v...) +} + +func Warningf(format string, v ...interface{}) { + _log.Warningf(format, v...) +} + +func Error(v ...interface{}) { + _log.Error(v...) +} + +func Errorf(format string, v ...interface{}) { + _log.Errorf(format, v...) +} + +func Fatal(v ...interface{}) { + _log.Fatal(v...) +} + +func Fatalf(format string, v ...interface{}) { + _log.Fatalf(format, v...) +} + +func SetLevelByString(level string) { + _log.SetLevelByString(level) +} + +func SetHighlighting(highlighting bool) { + _log.SetHighlighting(highlighting) +} + +func SetRotateByDay() { + _log.SetRotateByDay() +} + +func SetRotateByHour() { + _log.SetRotateByHour() +} + +type logger struct { + _log *log.Logger + level LogLevel + highlighting bool + + dailyRolling bool + hourRolling bool + + fileName string + logSuffix string + fd *os.File + + lock sync.Mutex +} + +func (l *logger) SetHighlighting(highlighting bool) { + l.highlighting = highlighting +} + +func (l *logger) SetLevel(level LogLevel) { + l.level = level +} + +func (l *logger) SetLevelByString(level string) { + l.level = StringToLogLevel(level) +} + +func (l *logger) SetRotateByDay() { + l.dailyRolling = true + l.logSuffix = genDayTime(time.Now()) +} + +func (l *logger) SetRotateByHour() { + l.hourRolling = true + l.logSuffix = genHourTime(time.Now()) +} + +func (l *logger) rotate() error { + l.lock.Lock() + defer l.lock.Unlock() + + var suffix string + if l.dailyRolling { + suffix = genDayTime(time.Now()) + } else if l.hourRolling { + suffix = genHourTime(time.Now()) + } else { + return nil + } + + // Notice: if suffix is not equal to l.LogSuffix, then rotate + if suffix != l.logSuffix { + err := l.doRotate(suffix) + if err != nil { + return err + } + } + + return nil +} + +func (l *logger) doRotate(suffix string) error { + // Notice: Not check error, is this ok? + l.fd.Close() + + lastFileName := l.fileName + "." + l.logSuffix + err := os.Rename(l.fileName, lastFileName) + if err != nil { + return err + } + + err = l.SetOutputByName(l.fileName) + if err != nil { + return err + } + + l.logSuffix = suffix + + return nil +} + +func (l *logger) SetOutput(out io.Writer) { + l._log = log.New(out, l._log.Prefix(), l._log.Flags()) +} + +func (l *logger) SetOutputByName(path string) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666) + if err != nil { + log.Fatal(err) + } + + l.SetOutput(f) + + l.fileName = path + l.fd = f + + return err +} + +func (l *logger) log(t LogType, v ...interface{}) { + if l.level|LogLevel(t) != l.level { + return + } + + err := l.rotate() + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return + } + + v1 := make([]interface{}, len(v)+2) + logStr, logColor := LogTypeToString(t) + if l.highlighting { + v1[0] = "\033" + logColor + "m[" + logStr + "]" + copy(v1[1:], v) + v1[len(v)+1] = "\033[0m" + } else { + v1[0] = "[" + logStr + "]" + copy(v1[1:], v) + v1[len(v)+1] = "" + } + + s := fmt.Sprintln(v1...) + l._log.Output(4, s) +} + +func (l *logger) logf(t LogType, format string, v ...interface{}) { + if l.level|LogLevel(t) != l.level { + return + } + + err := l.rotate() + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return + } + + logStr, logColor := LogTypeToString(t) + var s string + if l.highlighting { + s = "\033" + logColor + "m[" + logStr + "] " + fmt.Sprintf(format, v...) + "\033[0m" + } else { + s = "[" + logStr + "] " + fmt.Sprintf(format, v...) + } + l._log.Output(4, s) +} + +func (l *logger) Fatal(v ...interface{}) { + l.log(LOG_FATAL, v...) + os.Exit(-1) +} + +func (l *logger) Fatalf(format string, v ...interface{}) { + l.logf(LOG_FATAL, format, v...) + os.Exit(-1) +} + +func (l *logger) Error(v ...interface{}) { + l.log(LOG_ERROR, v...) +} + +func (l *logger) Errorf(format string, v ...interface{}) { + l.logf(LOG_ERROR, format, v...) +} + +func (l *logger) Warning(v ...interface{}) { + l.log(LOG_WARNING, v...) +} + +func (l *logger) Warningf(format string, v ...interface{}) { + l.logf(LOG_WARNING, format, v...) +} + +func (l *logger) Debug(v ...interface{}) { + l.log(LOG_DEBUG, v...) +} + +func (l *logger) Debugf(format string, v ...interface{}) { + l.logf(LOG_DEBUG, format, v...) +} + +func (l *logger) Info(v ...interface{}) { + l.log(LOG_INFO, v...) +} + +func (l *logger) Infof(format string, v ...interface{}) { + l.logf(LOG_INFO, format, v...) +} + +func StringToLogLevel(level string) LogLevel { + switch level { + case "fatal": + return LOG_LEVEL_FATAL + case "error": + return LOG_LEVEL_ERROR + case "warn": + return LOG_LEVEL_WARN + case "warning": + return LOG_LEVEL_WARN + case "debug": + return LOG_LEVEL_DEBUG + case "info": + return LOG_LEVEL_INFO + } + return LOG_LEVEL_ALL +} + +func LogTypeToString(t LogType) (string, string) { + switch t { + case LOG_FATAL: + return "fatal", "[0;31" + case LOG_ERROR: + return "error", "[0;31" + case LOG_WARNING: + return "warning", "[0;33" + case LOG_DEBUG: + return "debug", "[0;36" + case LOG_INFO: + return "info", "[0;37" + } + return "unknown", "[0;37" +} + +func genDayTime(t time.Time) string { + return t.Format(FORMAT_TIME_DAY) +} + +func genHourTime(t time.Time) string { + return t.Format(FORMAT_TIME_HOUR) +} + +func New() *logger { + return Newlogger(os.Stderr, "") +} + +func Newlogger(w io.Writer, prefix string) *logger { + return &logger{_log: log.New(w, prefix, LstdFlags), level: LOG_LEVEL_ALL, highlighting: true} +} diff --git a/vendor/github.com/ngaut/log/log_test.go b/vendor/github.com/ngaut/log/log_test.go new file mode 100644 index 0000000..adbcfb3 --- /dev/null +++ b/vendor/github.com/ngaut/log/log_test.go @@ -0,0 +1,197 @@ +package log + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" + "testing" + "time" +) + +func isFileExists(name string) bool { + f, err := os.Stat(name) + if err != nil { + if os.IsNotExist(err) { + return false + } + } + + if f.IsDir() { + return false + } + + return true +} + +func parseDate(value string, format string) (time.Time, error) { + tt, err := time.ParseInLocation(format, value, time.Local) + if err != nil { + fmt.Println("[Error]" + err.Error()) + return tt, err + } + + return tt, nil +} + +func checkLogData(fileName string, containData string, num int64) error { + input, err := os.OpenFile(fileName, os.O_RDONLY, 0) + if err != nil { + return err + } + defer input.Close() + + var lineNum int64 + br := bufio.NewReader(input) + for { + line, err := br.ReadString('\n') + if err == io.EOF { + break + } + + realLine := strings.TrimRight(line, "\n") + if strings.Contains(realLine, containData) { + lineNum += 1 + } + } + + // check whether num is equal to lineNum + if lineNum != num { + return fmt.Errorf("checkLogData fail - %d vs %d", lineNum, num) + } + + return nil +} + +func TestDayRotateCase(t *testing.T) { + _log = New() + + logName := "example_day_test.log" + if isFileExists(logName) { + err := os.Remove(logName) + if err != nil { + t.Errorf("Remove old log file fail - %s, %s\n", err.Error(), logName) + } + } + + SetRotateByDay() + err := SetOutputByName(logName) + if err != nil { + t.Errorf("SetOutputByName fail - %s, %s\n", err.Error(), logName) + } + + if _log.logSuffix == "" { + t.Errorf("bad log suffix fail - %s\n", _log.logSuffix) + } + + day, err := parseDate(_log.logSuffix, FORMAT_TIME_DAY) + if err != nil { + t.Errorf("parseDate fail - %s, %s\n", err.Error(), _log.logSuffix) + } + + _log.Info("Test data") + _log.Infof("Test data - %s", day.String()) + + // mock log suffix to check rotate + lastDay := day.AddDate(0, 0, -1) + _log.logSuffix = genDayTime(lastDay) + oldLogSuffix := _log.logSuffix + + _log.Info("Test new data") + _log.Infof("Test new data - %s", day.String()) + + err = _log.fd.Close() + if err != nil { + t.Errorf("close log fd fail - %s, %s\n", err.Error(), _log.fileName) + } + + // check both old and new log file datas + oldLogName := logName + "." + oldLogSuffix + err = checkLogData(oldLogName, "Test data", 2) + if err != nil { + t.Errorf("old log file checkLogData fail - %s, %s\n", err.Error(), oldLogName) + } + + err = checkLogData(logName, "Test new data", 2) + if err != nil { + t.Errorf("new log file checkLogData fail - %s, %s\n", err.Error(), logName) + } + + // remove test log files + err = os.Remove(oldLogName) + if err != nil { + t.Errorf("Remove final old log file fail - %s, %s\n", err.Error(), oldLogName) + } + + err = os.Remove(logName) + if err != nil { + t.Errorf("Remove final new log file fail - %s, %s\n", err.Error(), logName) + } +} + +func TestHourRotateCase(t *testing.T) { + _log = New() + + logName := "example_hour_test.log" + if isFileExists(logName) { + err := os.Remove(logName) + if err != nil { + t.Errorf("Remove old log file fail - %s, %s\n", err.Error(), logName) + } + } + + SetRotateByHour() + err := SetOutputByName(logName) + if err != nil { + t.Errorf("SetOutputByName fail - %s, %s\n", err.Error(), logName) + } + + if _log.logSuffix == "" { + t.Errorf("bad log suffix fail - %s\n", _log.logSuffix) + } + + hour, err := parseDate(_log.logSuffix, FORMAT_TIME_HOUR) + if err != nil { + t.Errorf("parseDate fail - %s, %s\n", err.Error(), _log.logSuffix) + } + + _log.Info("Test data") + _log.Infof("Test data - %s", hour.String()) + + // mock log suffix to check rotate + lastHour := hour.Add(time.Duration(-1 * time.Hour)) + _log.logSuffix = genHourTime(lastHour) + oldLogSuffix := _log.logSuffix + + _log.Info("Test new data") + _log.Infof("Test new data - %s", hour.String()) + + err = _log.fd.Close() + if err != nil { + t.Errorf("close log fd fail - %s, %s\n", err.Error(), _log.fileName) + } + + // check both old and new log file datas + oldLogName := logName + "." + oldLogSuffix + err = checkLogData(oldLogName, "Test data", 2) + if err != nil { + t.Errorf("old log file checkLogData fail - %s, %s\n", err.Error(), oldLogName) + } + + err = checkLogData(logName, "Test new data", 2) + if err != nil { + t.Errorf("new log file checkLogData fail - %s, %s\n", err.Error(), logName) + } + + // remove test log files + err = os.Remove(oldLogName) + if err != nil { + t.Errorf("Remove final old log file fail - %s, %s\n", err.Error(), oldLogName) + } + + err = os.Remove(logName) + if err != nil { + t.Errorf("Remove final new log file fail - %s, %s\n", err.Error(), logName) + } +} diff --git a/vendor/github.com/ngaut/pools/id_pool.go b/vendor/github.com/ngaut/pools/id_pool.go new file mode 100644 index 0000000..31db606 --- /dev/null +++ b/vendor/github.com/ngaut/pools/id_pool.go @@ -0,0 +1,72 @@ +// Copyright 2014, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pools + +import ( + "fmt" + "sync" +) + +// IDPool is used to ensure that the set of IDs in use concurrently never +// contains any duplicates. The IDs start at 1 and increase without bound, but +// will never be larger than the peak number of concurrent uses. +// +// IDPool's Get() and Set() methods can be used concurrently. +type IDPool struct { + sync.Mutex + + // used holds the set of values that have been returned to us with Put(). + used map[uint32]bool + // maxUsed remembers the largest value we've given out. + maxUsed uint32 +} + +// NewIDPool creates and initializes an IDPool. +func NewIDPool() *IDPool { + return &IDPool{ + used: make(map[uint32]bool), + } +} + +// Get returns an ID that is unique among currently active users of this pool. +func (pool *IDPool) Get() (id uint32) { + pool.Lock() + defer pool.Unlock() + + // Pick a value that's been returned, if any. + for key, _ := range pool.used { + delete(pool.used, key) + return key + } + + // No recycled IDs are available, so increase the pool size. + pool.maxUsed += 1 + return pool.maxUsed +} + +// Put recycles an ID back into the pool for others to use. Putting back a value +// or 0, or a value that is not currently "checked out", will result in a panic +// because that should never happen except in the case of a programming error. +func (pool *IDPool) Put(id uint32) { + pool.Lock() + defer pool.Unlock() + + if id < 1 || id > pool.maxUsed { + panic(fmt.Errorf("IDPool.Put(%v): invalid value, must be in the range [1,%v]", id, pool.maxUsed)) + } + + if pool.used[id] { + panic(fmt.Errorf("IDPool.Put(%v): can't put value that was already recycled", id)) + } + + // If we're recycling maxUsed, just shrink the pool. + if id == pool.maxUsed { + pool.maxUsed = id - 1 + return + } + + // Add it to the set of recycled IDs. + pool.used[id] = true +} diff --git a/vendor/github.com/ngaut/pools/id_pool_test.go b/vendor/github.com/ngaut/pools/id_pool_test.go new file mode 100644 index 0000000..a437874 --- /dev/null +++ b/vendor/github.com/ngaut/pools/id_pool_test.go @@ -0,0 +1,118 @@ +// Copyright 2014, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pools + +import ( + "reflect" + "strings" + "testing" +) + +func (pool *IDPool) want(want *IDPool, t *testing.T) { + if pool.maxUsed != want.maxUsed { + t.Errorf("pool.maxUsed = %#v, want %#v", pool.maxUsed, want.maxUsed) + } + + if !reflect.DeepEqual(pool.used, want.used) { + t.Errorf("pool.used = %#v, want %#v", pool.used, want.used) + } +} + +func TestIDPoolFirstGet(t *testing.T) { + pool := NewIDPool() + + if got := pool.Get(); got != 1 { + t.Errorf("pool.Get() = %v, want 1", got) + } + + pool.want(&IDPool{used: map[uint32]bool{}, maxUsed: 1}, t) +} + +func TestIDPoolSecondGet(t *testing.T) { + pool := NewIDPool() + pool.Get() + + if got := pool.Get(); got != 2 { + t.Errorf("pool.Get() = %v, want 2", got) + } + + pool.want(&IDPool{used: map[uint32]bool{}, maxUsed: 2}, t) +} + +func TestIDPoolPutToUsedSet(t *testing.T) { + pool := NewIDPool() + id1 := pool.Get() + pool.Get() + pool.Put(id1) + + pool.want(&IDPool{used: map[uint32]bool{1: true}, maxUsed: 2}, t) +} + +func TestIDPoolPutMaxUsed1(t *testing.T) { + pool := NewIDPool() + id1 := pool.Get() + pool.Put(id1) + + pool.want(&IDPool{used: map[uint32]bool{}, maxUsed: 0}, t) +} + +func TestIDPoolPutMaxUsed2(t *testing.T) { + pool := NewIDPool() + pool.Get() + id2 := pool.Get() + pool.Put(id2) + + pool.want(&IDPool{used: map[uint32]bool{}, maxUsed: 1}, t) +} + +func TestIDPoolGetFromUsedSet(t *testing.T) { + pool := NewIDPool() + id1 := pool.Get() + pool.Get() + pool.Put(id1) + + if got := pool.Get(); got != 1 { + t.Errorf("pool.Get() = %v, want 1", got) + } + + pool.want(&IDPool{used: map[uint32]bool{}, maxUsed: 2}, t) +} + +func wantError(want string, t *testing.T) { + rec := recover() + if rec == nil { + t.Errorf("expected panic, but there wasn't one") + } + err, ok := rec.(error) + if !ok || !strings.Contains(err.Error(), want) { + t.Errorf("wrong error, got '%v', want '%v'", err, want) + } +} + +func TestIDPoolPut0(t *testing.T) { + pool := NewIDPool() + pool.Get() + + defer wantError("invalid value", t) + pool.Put(0) +} + +func TestIDPoolPutInvalid(t *testing.T) { + pool := NewIDPool() + pool.Get() + + defer wantError("invalid value", t) + pool.Put(5) +} + +func TestIDPoolPutDuplicate(t *testing.T) { + pool := NewIDPool() + pool.Get() + pool.Get() + pool.Put(1) + + defer wantError("already recycled", t) + pool.Put(1) +} diff --git a/vendor/github.com/ngaut/pools/numbered.go b/vendor/github.com/ngaut/pools/numbered.go new file mode 100644 index 0000000..e170e03 --- /dev/null +++ b/vendor/github.com/ngaut/pools/numbered.go @@ -0,0 +1,149 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pools + +import ( + "fmt" + "sync" + "time" +) + +// Numbered allows you to manage resources by tracking them with numbers. +// There are no interface restrictions on what you can track. +type Numbered struct { + mu sync.Mutex + empty *sync.Cond // Broadcast when pool becomes empty + resources map[int64]*numberedWrapper +} + +type numberedWrapper struct { + val interface{} + inUse bool + purpose string + timeCreated time.Time + timeUsed time.Time +} + +func NewNumbered() *Numbered { + n := &Numbered{resources: make(map[int64]*numberedWrapper)} + n.empty = sync.NewCond(&n.mu) + return n +} + +// Register starts tracking a resource by the supplied id. +// It does not lock the object. +// It returns an error if the id already exists. +func (nu *Numbered) Register(id int64, val interface{}) error { + nu.mu.Lock() + defer nu.mu.Unlock() + if _, ok := nu.resources[id]; ok { + return fmt.Errorf("already present") + } + now := time.Now() + nu.resources[id] = &numberedWrapper{ + val: val, + timeCreated: now, + timeUsed: now, + } + return nil +} + +// Unregiester forgets the specified resource. +// If the resource is not present, it's ignored. +func (nu *Numbered) Unregister(id int64) { + nu.mu.Lock() + defer nu.mu.Unlock() + delete(nu.resources, id) + if len(nu.resources) == 0 { + nu.empty.Broadcast() + } +} + +// Get locks the resource for use. It accepts a purpose as a string. +// If it cannot be found, it returns a "not found" error. If in use, +// it returns a "in use: purpose" error. +func (nu *Numbered) Get(id int64, purpose string) (val interface{}, err error) { + nu.mu.Lock() + defer nu.mu.Unlock() + nw, ok := nu.resources[id] + if !ok { + return nil, fmt.Errorf("not found") + } + if nw.inUse { + return nil, fmt.Errorf("in use: %s", nw.purpose) + } + nw.inUse = true + nw.purpose = purpose + return nw.val, nil +} + +// Put unlocks a resource for someone else to use. +func (nu *Numbered) Put(id int64) { + nu.mu.Lock() + defer nu.mu.Unlock() + if nw, ok := nu.resources[id]; ok { + nw.inUse = false + nw.purpose = "" + nw.timeUsed = time.Now() + } +} + +// GetOutdated returns a list of resources that are older than age, and locks them. +// It does not return any resources that are already locked. +func (nu *Numbered) GetOutdated(age time.Duration, purpose string) (vals []interface{}) { + nu.mu.Lock() + defer nu.mu.Unlock() + now := time.Now() + for _, nw := range nu.resources { + if nw.inUse { + continue + } + if nw.timeCreated.Add(age).Sub(now) <= 0 { + nw.inUse = true + nw.purpose = purpose + vals = append(vals, nw.val) + } + } + return vals +} + +// GetIdle returns a list of resurces that have been idle for longer +// than timeout, and locks them. It does not return any resources that +// are already locked. +func (nu *Numbered) GetIdle(timeout time.Duration, purpose string) (vals []interface{}) { + nu.mu.Lock() + defer nu.mu.Unlock() + now := time.Now() + for _, nw := range nu.resources { + if nw.inUse { + continue + } + if nw.timeUsed.Add(timeout).Sub(now) <= 0 { + nw.inUse = true + nw.purpose = purpose + vals = append(vals, nw.val) + } + } + return vals +} + +// WaitForEmpty returns as soon as the pool becomes empty +func (nu *Numbered) WaitForEmpty() { + nu.mu.Lock() + defer nu.mu.Unlock() + for len(nu.resources) != 0 { + nu.empty.Wait() + } +} + +func (nu *Numbered) StatsJSON() string { + return fmt.Sprintf("{\"Size\": %v}", nu.Size()) +} + +func (nu *Numbered) Size() (size int64) { + nu.mu.Lock() + defer nu.mu.Unlock() + return int64(len(nu.resources)) +} diff --git a/vendor/github.com/ngaut/pools/numbered_test.go b/vendor/github.com/ngaut/pools/numbered_test.go new file mode 100644 index 0000000..54d5946 --- /dev/null +++ b/vendor/github.com/ngaut/pools/numbered_test.go @@ -0,0 +1,84 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pools + +import ( + "testing" + "time" +) + +func TestNumbered(t *testing.T) { + id := int64(0) + p := NewNumbered() + + var err error + if err = p.Register(id, id); err != nil { + t.Errorf("Error %v", err) + } + if err = p.Register(id, id); err.Error() != "already present" { + t.Errorf("want 'already present', got '%v'", err) + } + var v interface{} + if v, err = p.Get(id, "test"); err != nil { + t.Errorf("Error %v", err) + } + if v.(int64) != id { + t.Errorf("want %v, got %v", id, v.(int64)) + } + if v, err = p.Get(id, "test1"); err.Error() != "in use: test" { + t.Errorf("want 'in use: test', got '%v'", err) + } + p.Put(id) + if v, err = p.Get(1, "test2"); err.Error() != "not found" { + t.Errorf("want 'not found', got '%v'", err) + } + p.Unregister(1) // Should not fail + p.Unregister(0) + // p is now empty + + p.Register(id, id) + id++ + p.Register(id, id) + time.Sleep(300 * time.Millisecond) + id++ + p.Register(id, id) + time.Sleep(100 * time.Millisecond) + + // p has 0, 1, 2 (0 & 1 are aged) + vals := p.GetOutdated(200*time.Millisecond, "by outdated") + if len(vals) != 2 { + t.Errorf("want 2, got %v", len(vals)) + } + if v, err = p.Get(vals[0].(int64), "test1"); err.Error() != "in use: by outdated" { + t.Errorf("want 'in use: by outdated', got '%v'", err) + } + for _, v := range vals { + p.Put(v.(int64)) + } + time.Sleep(100 * time.Millisecond) + + // p has 0, 1, 2 (2 is idle) + vals = p.GetIdle(200*time.Millisecond, "by idle") + if len(vals) != 1 { + t.Errorf("want 1, got %v", len(vals)) + } + if v, err = p.Get(vals[0].(int64), "test1"); err.Error() != "in use: by idle" { + t.Errorf("want 'in use: by idle', got '%v'", err) + } + if vals[0].(int64) != 2 { + t.Errorf("want 2, got %v", vals[0]) + } + p.Unregister(vals[0].(int64)) + + // p has 0 & 1 + if p.Size() != 2 { + t.Errorf("want 2, got %v", p.Size()) + } + go func() { + p.Unregister(0) + p.Unregister(1) + }() + p.WaitForEmpty() +} diff --git a/vendor/github.com/ngaut/pools/resource_pool.go b/vendor/github.com/ngaut/pools/resource_pool.go new file mode 100644 index 0000000..b02cf04 --- /dev/null +++ b/vendor/github.com/ngaut/pools/resource_pool.go @@ -0,0 +1,228 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pools provides functionality to manage and reuse resources +// like connections. +package pools + +import ( + "fmt" + "time" + + "github.com/ngaut/sync2" +) + +var ( + CLOSED_ERR = fmt.Errorf("ResourcePool is closed") +) + +// Factory is a function that can be used to create a resource. +type Factory func() (Resource, error) + +// Every resource needs to suport the Resource interface. +// Thread synchronization between Close() and IsClosed() +// is the responsibility the caller. +type Resource interface { + Close() +} + +// ResourcePool allows you to use a pool of resources. +type ResourcePool struct { + resources chan resourceWrapper + factory Factory + capacity sync2.AtomicInt64 + idleTimeout sync2.AtomicDuration + + // stats + waitCount sync2.AtomicInt64 + waitTime sync2.AtomicDuration +} + +type resourceWrapper struct { + resource Resource + timeUsed time.Time +} + +// NewResourcePool creates a new ResourcePool pool. +// capacity is the initial capacity of the pool. +// maxCap is the maximum capacity. +// If a resource is unused beyond idleTimeout, it's discarded. +// An idleTimeout of 0 means that there is no timeout. +func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Duration) *ResourcePool { + if capacity <= 0 || maxCap <= 0 || capacity > maxCap { + panic(fmt.Errorf("Invalid/out of range capacity")) + } + rp := &ResourcePool{ + resources: make(chan resourceWrapper, maxCap), + factory: factory, + capacity: sync2.AtomicInt64(capacity), + idleTimeout: sync2.AtomicDuration(idleTimeout), + } + for i := 0; i < capacity; i++ { + rp.resources <- resourceWrapper{} + } + return rp +} + +// Close empties the pool calling Close on all its resources. +// You can call Close while there are outstanding resources. +// It waits for all resources to be returned (Put). +// After a Close, Get and TryGet are not allowed. +func (rp *ResourcePool) Close() { + rp.SetCapacity(0) +} + +func (rp *ResourcePool) IsClosed() (closed bool) { + return rp.capacity.Get() == 0 +} + +// Get will return the next available resource. If capacity +// has not been reached, it will create a new one using the factory. Otherwise, +// it will indefinitely wait till the next resource becomes available. +func (rp *ResourcePool) Get() (resource Resource, err error) { + return rp.get(true) +} + +// TryGet will return the next available resource. If none is available, and capacity +// has not been reached, it will create a new one using the factory. Otherwise, +// it will return nil with no error. +func (rp *ResourcePool) TryGet() (resource Resource, err error) { + return rp.get(false) +} + +func (rp *ResourcePool) get(wait bool) (resource Resource, err error) { + // Fetch + var wrapper resourceWrapper + var ok bool + select { + case wrapper, ok = <-rp.resources: + default: + if !wait { + return nil, nil + } + startTime := time.Now() + wrapper, ok = <-rp.resources + rp.recordWait(startTime) + } + if !ok { + return nil, CLOSED_ERR + } + + // Unwrap + timeout := rp.idleTimeout.Get() + if wrapper.resource != nil && timeout > 0 && wrapper.timeUsed.Add(timeout).Sub(time.Now()) < 0 { + wrapper.resource.Close() + wrapper.resource = nil + } + if wrapper.resource == nil { + wrapper.resource, err = rp.factory() + if err != nil { + rp.resources <- resourceWrapper{} + } + } + return wrapper.resource, err +} + +// Put will return a resource to the pool. For every successful Get, +// a corresponding Put is required. If you no longer need a resource, +// you will need to call Put(nil) instead of returning the closed resource. +// The will eventually cause a new resource to be created in its place. +func (rp *ResourcePool) Put(resource Resource) { + var wrapper resourceWrapper + if resource != nil { + wrapper = resourceWrapper{resource, time.Now()} + } + select { + case rp.resources <- wrapper: + default: + panic(fmt.Errorf("Attempt to Put into a full ResourcePool")) + } +} + +// SetCapacity changes the capacity of the pool. +// You can use it to shrink or expand, but not beyond +// the max capacity. If the change requires the pool +// to be shrunk, SetCapacity waits till the necessary +// number of resources are returned to the pool. +// A SetCapacity of 0 is equivalent to closing the ResourcePool. +func (rp *ResourcePool) SetCapacity(capacity int) error { + if capacity < 0 || capacity > cap(rp.resources) { + return fmt.Errorf("capacity %d is out of range", capacity) + } + + // Atomically swap new capacity with old, but only + // if old capacity is non-zero. + var oldcap int + for { + oldcap = int(rp.capacity.Get()) + if oldcap == 0 { + return CLOSED_ERR + } + if oldcap == capacity { + return nil + } + if rp.capacity.CompareAndSwap(int64(oldcap), int64(capacity)) { + break + } + } + + if capacity < oldcap { + for i := 0; i < oldcap-capacity; i++ { + wrapper := <-rp.resources + if wrapper.resource != nil { + wrapper.resource.Close() + } + } + } else { + for i := 0; i < capacity-oldcap; i++ { + rp.resources <- resourceWrapper{} + } + } + if capacity == 0 { + close(rp.resources) + } + return nil +} + +func (rp *ResourcePool) recordWait(start time.Time) { + rp.waitCount.Add(1) + rp.waitTime.Add(time.Now().Sub(start)) +} + +func (rp *ResourcePool) SetIdleTimeout(idleTimeout time.Duration) { + rp.idleTimeout.Set(idleTimeout) +} + +func (rp *ResourcePool) StatsJSON() string { + c, a, mx, wc, wt, it := rp.Stats() + return fmt.Sprintf(`{"Capacity": %v, "Available": %v, "MaxCapacity": %v, "WaitCount": %v, "WaitTime": %v, "IdleTimeout": %v}`, c, a, mx, wc, int64(wt), int64(it)) +} + +func (rp *ResourcePool) Stats() (capacity, available, maxCap, waitCount int64, waitTime, idleTimeout time.Duration) { + return rp.Capacity(), rp.Available(), rp.MaxCap(), rp.WaitCount(), rp.WaitTime(), rp.IdleTimeout() +} + +func (rp *ResourcePool) Capacity() int64 { + return rp.capacity.Get() +} + +func (rp *ResourcePool) Available() int64 { + return int64(len(rp.resources)) +} + +func (rp *ResourcePool) MaxCap() int64 { + return int64(cap(rp.resources)) +} + +func (rp *ResourcePool) WaitCount() int64 { + return rp.waitCount.Get() +} + +func (rp *ResourcePool) WaitTime() time.Duration { + return rp.waitTime.Get() +} + +func (rp *ResourcePool) IdleTimeout() time.Duration { + return rp.idleTimeout.Get() +} diff --git a/vendor/github.com/ngaut/pools/resource_pool_test.go b/vendor/github.com/ngaut/pools/resource_pool_test.go new file mode 100644 index 0000000..27820dd --- /dev/null +++ b/vendor/github.com/ngaut/pools/resource_pool_test.go @@ -0,0 +1,487 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pools + +import ( + "errors" + "testing" + "time" + + "github.com/ngaut/sync2" +) + +var lastId, count sync2.AtomicInt64 + +type TestResource struct { + num int64 + closed bool +} + +func (tr *TestResource) Close() { + if !tr.closed { + count.Add(-1) + tr.closed = true + } +} + +func (tr *TestResource) IsClosed() bool { + return tr.closed +} + +func PoolFactory() (Resource, error) { + count.Add(1) + return &TestResource{lastId.Add(1), false}, nil +} + +func FailFactory() (Resource, error) { + return nil, errors.New("Failed") +} + +func SlowFailFactory() (Resource, error) { + time.Sleep(10 * time.Nanosecond) + return nil, errors.New("Failed") +} + +func TestOpen(t *testing.T) { + lastId.Set(0) + count.Set(0) + p := NewResourcePool(PoolFactory, 6, 6, time.Second) + p.SetCapacity(5) + var resources [10]Resource + + // Test Get + for i := 0; i < 5; i++ { + r, err := p.Get() + resources[i] = r + if err != nil { + t.Errorf("Unexpected error %v", err) + } + _, available, _, waitCount, waitTime, _ := p.Stats() + if available != int64(5-i-1) { + t.Errorf("expecting %d, received %d", 5-i-1, available) + } + if waitCount != 0 { + t.Errorf("expecting 0, received %d", waitCount) + } + if waitTime != 0 { + t.Errorf("expecting 0, received %d", waitTime) + } + if lastId.Get() != int64(i+1) { + t.Errorf("Expecting %d, received %d", i+1, lastId.Get()) + } + if count.Get() != int64(i+1) { + t.Errorf("Expecting %d, received %d", i+1, count.Get()) + } + } + + // Test TryGet + r, err := p.TryGet() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + if r != nil { + t.Errorf("Expecting nil") + } + for i := 0; i < 5; i++ { + p.Put(resources[i]) + _, available, _, _, _, _ := p.Stats() + if available != int64(i+1) { + t.Errorf("expecting %d, received %d", 5-i-1, available) + } + } + for i := 0; i < 5; i++ { + r, err := p.TryGet() + resources[i] = r + if err != nil { + t.Errorf("Unexpected error %v", err) + } + if r == nil { + t.Errorf("Expecting non-nil") + } + if lastId.Get() != 5 { + t.Errorf("Expecting 5, received %d", lastId.Get()) + } + if count.Get() != 5 { + t.Errorf("Expecting 5, received %d", count.Get()) + } + } + + // Test that Get waits + ch := make(chan bool) + go func() { + for i := 0; i < 5; i++ { + r, err := p.Get() + if err != nil { + t.Errorf("Get failed: %v", err) + } + resources[i] = r + } + for i := 0; i < 5; i++ { + p.Put(resources[i]) + } + ch <- true + }() + for i := 0; i < 5; i++ { + // Sleep to ensure the goroutine waits + time.Sleep(10 * time.Nanosecond) + p.Put(resources[i]) + } + <-ch + _, _, _, waitCount, waitTime, _ := p.Stats() + if waitCount != 5 { + t.Errorf("Expecting 5, received %d", waitCount) + } + if waitTime == 0 { + t.Errorf("Expecting non-zero") + } + if lastId.Get() != 5 { + t.Errorf("Expecting 5, received %d", lastId.Get()) + } + + // Test Close resource + r, err = p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + r.Close() + p.Put(nil) + if count.Get() != 4 { + t.Errorf("Expecting 4, received %d", count.Get()) + } + for i := 0; i < 5; i++ { + r, err := p.Get() + if err != nil { + t.Errorf("Get failed: %v", err) + } + resources[i] = r + } + for i := 0; i < 5; i++ { + p.Put(resources[i]) + } + if count.Get() != 5 { + t.Errorf("Expecting 5, received %d", count.Get()) + } + if lastId.Get() != 6 { + t.Errorf("Expecting 6, received %d", lastId.Get()) + } + + // SetCapacity + p.SetCapacity(3) + if count.Get() != 3 { + t.Errorf("Expecting 3, received %d", count.Get()) + } + if lastId.Get() != 6 { + t.Errorf("Expecting 6, received %d", lastId.Get()) + } + capacity, available, _, _, _, _ := p.Stats() + if capacity != 3 { + t.Errorf("Expecting 3, received %d", capacity) + } + if available != 3 { + t.Errorf("Expecting 3, received %d", available) + } + p.SetCapacity(6) + capacity, available, _, _, _, _ = p.Stats() + if capacity != 6 { + t.Errorf("Expecting 6, received %d", capacity) + } + if available != 6 { + t.Errorf("Expecting 6, received %d", available) + } + for i := 0; i < 6; i++ { + r, err := p.Get() + if err != nil { + t.Errorf("Get failed: %v", err) + } + resources[i] = r + } + for i := 0; i < 6; i++ { + p.Put(resources[i]) + } + if count.Get() != 6 { + t.Errorf("Expecting 5, received %d", count.Get()) + } + if lastId.Get() != 9 { + t.Errorf("Expecting 9, received %d", lastId.Get()) + } + + // Close + p.Close() + capacity, available, _, _, _, _ = p.Stats() + if capacity != 0 { + t.Errorf("Expecting 0, received %d", capacity) + } + if available != 0 { + t.Errorf("Expecting 0, received %d", available) + } + if count.Get() != 0 { + t.Errorf("Expecting 0, received %d", count.Get()) + } +} + +func TestShrinking(t *testing.T) { + lastId.Set(0) + count.Set(0) + p := NewResourcePool(PoolFactory, 5, 5, time.Second) + var resources [10]Resource + // Leave one empty slot in the pool + for i := 0; i < 4; i++ { + r, err := p.Get() + if err != nil { + t.Errorf("Get failed: %v", err) + } + resources[i] = r + } + go p.SetCapacity(3) + time.Sleep(10 * time.Nanosecond) + stats := p.StatsJSON() + expected := `{"Capacity": 3, "Available": 0, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000}` + if stats != expected { + t.Errorf(`expecting '%s', received '%s'`, expected, stats) + } + + // TryGet is allowed when shrinking + r, err := p.TryGet() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + if r != nil { + t.Errorf("Expecting nil") + } + + // Get is allowed when shrinking, but it will wait + getdone := make(chan bool) + go func() { + r, err := p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + p.Put(r) + getdone <- true + }() + + // Put is allowed when shrinking. It's necessary. + for i := 0; i < 4; i++ { + p.Put(resources[i]) + } + // Wait for Get test to complete + <-getdone + stats = p.StatsJSON() + expected = `{"Capacity": 3, "Available": 3, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000}` + if stats != expected { + t.Errorf(`expecting '%s', received '%s'`, expected, stats) + } + if count.Get() != 3 { + t.Errorf("Expecting 3, received %d", count.Get()) + } + + // Ensure no deadlock if SetCapacity is called after we start + // waiting for a resource + for i := 0; i < 3; i++ { + resources[i], err = p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + } + // This will wait because pool is empty + go func() { + r, err := p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + p.Put(r) + getdone <- true + }() + time.Sleep(10 * time.Nanosecond) + + // This will wait till we Put + go p.SetCapacity(2) + time.Sleep(10 * time.Nanosecond) + + // This should not hang + for i := 0; i < 3; i++ { + p.Put(resources[i]) + } + <-getdone + capacity, available, _, _, _, _ := p.Stats() + if capacity != 2 { + t.Errorf("Expecting 2, received %d", capacity) + } + if available != 2 { + t.Errorf("Expecting 2, received %d", available) + } + if count.Get() != 2 { + t.Errorf("Expecting 2, received %d", count.Get()) + } + + // Test race condition of SetCapacity with itself + p.SetCapacity(3) + for i := 0; i < 3; i++ { + resources[i], err = p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + } + // This will wait because pool is empty + go func() { + r, err := p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + p.Put(r) + getdone <- true + }() + time.Sleep(10 * time.Nanosecond) + + // This will wait till we Put + go p.SetCapacity(2) + time.Sleep(10 * time.Nanosecond) + go p.SetCapacity(4) + time.Sleep(10 * time.Nanosecond) + + // This should not hang + for i := 0; i < 3; i++ { + p.Put(resources[i]) + } + <-getdone + + err = p.SetCapacity(-1) + if err == nil { + t.Errorf("Expecting error") + } + err = p.SetCapacity(255555) + if err == nil { + t.Errorf("Expecting error") + } + + capacity, available, _, _, _, _ = p.Stats() + if capacity != 4 { + t.Errorf("Expecting 4, received %d", capacity) + } + if available != 4 { + t.Errorf("Expecting 4, received %d", available) + } +} + +func TestClosing(t *testing.T) { + lastId.Set(0) + count.Set(0) + p := NewResourcePool(PoolFactory, 5, 5, time.Second) + var resources [10]Resource + for i := 0; i < 5; i++ { + r, err := p.Get() + if err != nil { + t.Errorf("Get failed: %v", err) + } + resources[i] = r + } + ch := make(chan bool) + go func() { + p.Close() + ch <- true + }() + + // Wait for goroutine to call Close + time.Sleep(10 * time.Nanosecond) + stats := p.StatsJSON() + expected := `{"Capacity": 0, "Available": 0, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000}` + if stats != expected { + t.Errorf(`expecting '%s', received '%s'`, expected, stats) + } + + // Put is allowed when closing + for i := 0; i < 5; i++ { + p.Put(resources[i]) + } + + // Wait for Close to return + <-ch + + // SetCapacity must be ignored after Close + err := p.SetCapacity(1) + if err == nil { + t.Errorf("expecting error") + } + + stats = p.StatsJSON() + expected = `{"Capacity": 0, "Available": 0, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000}` + if stats != expected { + t.Errorf(`expecting '%s', received '%s'`, expected, stats) + } + if lastId.Get() != 5 { + t.Errorf("Expecting 5, received %d", count.Get()) + } + if count.Get() != 0 { + t.Errorf("Expecting 0, received %d", count.Get()) + } +} + +func TestIdleTimeout(t *testing.T) { + lastId.Set(0) + count.Set(0) + p := NewResourcePool(PoolFactory, 1, 1, 10*time.Nanosecond) + defer p.Close() + + r, err := p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + p.Put(r) + if lastId.Get() != 1 { + t.Errorf("Expecting 1, received %d", count.Get()) + } + if count.Get() != 1 { + t.Errorf("Expecting 1, received %d", count.Get()) + } + time.Sleep(20 * time.Nanosecond) + r, err = p.Get() + if err != nil { + t.Errorf("Unexpected error %v", err) + } + if lastId.Get() != 2 { + t.Errorf("Expecting 2, received %d", count.Get()) + } + if count.Get() != 1 { + t.Errorf("Expecting 1, received %d", count.Get()) + } + p.Put(r) +} + +func TestCreateFail(t *testing.T) { + lastId.Set(0) + count.Set(0) + p := NewResourcePool(FailFactory, 5, 5, time.Second) + defer p.Close() + if _, err := p.Get(); err.Error() != "Failed" { + t.Errorf("Expecting Failed, received %v", err) + } + stats := p.StatsJSON() + expected := `{"Capacity": 5, "Available": 5, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000}` + if stats != expected { + t.Errorf(`expecting '%s', received '%s'`, expected, stats) + } +} + +func TestSlowCreateFail(t *testing.T) { + lastId.Set(0) + count.Set(0) + p := NewResourcePool(SlowFailFactory, 2, 2, time.Second) + defer p.Close() + ch := make(chan bool) + // The third Get should not wait indefinitely + for i := 0; i < 3; i++ { + go func() { + p.Get() + ch <- true + }() + } + for i := 0; i < 3; i++ { + <-ch + } + _, available, _, _, _, _ := p.Stats() + if available != 2 { + t.Errorf("Expecting 2, received %d", available) + } +} diff --git a/vendor/github.com/ngaut/pools/roundrobin.go b/vendor/github.com/ngaut/pools/roundrobin.go new file mode 100644 index 0000000..b06985f --- /dev/null +++ b/vendor/github.com/ngaut/pools/roundrobin.go @@ -0,0 +1,214 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pools + +import ( + "fmt" + "sync" + "time" +) + +// RoundRobin is deprecated. Use ResourcePool instead. +// RoundRobin allows you to use a pool of resources in a round robin fashion. +type RoundRobin struct { + mu sync.Mutex + available *sync.Cond + resources chan fifoWrapper + size int64 + factory Factory + idleTimeout time.Duration + + // stats + waitCount int64 + waitTime time.Duration +} + +type fifoWrapper struct { + resource Resource + timeUsed time.Time +} + +// NewRoundRobin creates a new RoundRobin pool. +// capacity is the maximum number of resources RoundRobin will create. +// factory will be the function used to create resources. +// If a resource is unused beyond idleTimeout, it's discarded. +func NewRoundRobin(capacity int, idleTimeout time.Duration) *RoundRobin { + r := &RoundRobin{ + resources: make(chan fifoWrapper, capacity), + size: 0, + idleTimeout: idleTimeout, + } + r.available = sync.NewCond(&r.mu) + return r +} + +// Open starts allowing the creation of resources +func (rr *RoundRobin) Open(factory Factory) { + rr.mu.Lock() + defer rr.mu.Unlock() + rr.factory = factory +} + +// Close empties the pool calling Close on all its resources. +// It waits for all resources to be returned (Put). +func (rr *RoundRobin) Close() { + rr.mu.Lock() + defer rr.mu.Unlock() + for rr.size > 0 { + select { + case fw := <-rr.resources: + go fw.resource.Close() + rr.size-- + default: + rr.available.Wait() + } + } + rr.factory = nil +} + +func (rr *RoundRobin) IsClosed() bool { + return rr.factory == nil +} + +// Get will return the next available resource. If none is available, and capacity +// has not been reached, it will create a new one using the factory. Otherwise, +// it will indefinitely wait till the next resource becomes available. +func (rr *RoundRobin) Get() (resource Resource, err error) { + return rr.get(true) +} + +// TryGet will return the next available resource. If none is available, and capacity +// has not been reached, it will create a new one using the factory. Otherwise, +// it will return nil with no error. +func (rr *RoundRobin) TryGet() (resource Resource, err error) { + return rr.get(false) +} + +func (rr *RoundRobin) get(wait bool) (resource Resource, err error) { + rr.mu.Lock() + defer rr.mu.Unlock() + // Any waits in this loop will release the lock, and it will be + // reacquired before the waits return. + for { + select { + case fw := <-rr.resources: + // Found a free resource in the channel + if rr.idleTimeout > 0 && fw.timeUsed.Add(rr.idleTimeout).Sub(time.Now()) < 0 { + // resource has been idle for too long. Discard & go for next. + go fw.resource.Close() + rr.size-- + // Nobody else should be waiting, but signal anyway. + rr.available.Signal() + continue + } + return fw.resource, nil + default: + // resource channel is empty + if rr.size >= int64(cap(rr.resources)) { + // The pool is full + if wait { + start := time.Now() + rr.available.Wait() + rr.recordWait(start) + continue + } + return nil, nil + } + // Pool is not full. Create a resource. + if resource, err = rr.waitForCreate(); err != nil { + // size was decremented, and somebody could be waiting. + rr.available.Signal() + return nil, err + } + // Creation successful. Account for this by incrementing size. + rr.size++ + return resource, err + } + } +} + +func (rr *RoundRobin) recordWait(start time.Time) { + rr.waitCount++ + rr.waitTime += time.Now().Sub(start) +} + +func (rr *RoundRobin) waitForCreate() (resource Resource, err error) { + // Prevent thundering herd: increment size before creating resource, and decrement after. + rr.size++ + rr.mu.Unlock() + defer func() { + rr.mu.Lock() + rr.size-- + }() + return rr.factory() +} + +// Put will return a resource to the pool. You MUST return every resource to the pool, +// even if it's closed. If a resource is closed, you should call Put(nil). +func (rr *RoundRobin) Put(resource Resource) { + rr.mu.Lock() + defer rr.available.Signal() + defer rr.mu.Unlock() + + if rr.size > int64(cap(rr.resources)) { + if resource != nil { + go resource.Close() + } + rr.size-- + } else if resource == nil { + rr.size-- + } else { + if len(rr.resources) == cap(rr.resources) { + panic("unexpected") + } + rr.resources <- fifoWrapper{resource, time.Now()} + } +} + +// Set capacity changes the capacity of the pool. +// You can use it to expand or shrink. +func (rr *RoundRobin) SetCapacity(capacity int) error { + rr.mu.Lock() + defer rr.available.Broadcast() + defer rr.mu.Unlock() + + nr := make(chan fifoWrapper, capacity) + // This loop transfers resources from the old channel + // to the new one, until it fills up or runs out. + // It discards extras, if any. + for { + select { + case fw := <-rr.resources: + if len(nr) < cap(nr) { + nr <- fw + } else { + go fw.resource.Close() + rr.size-- + } + continue + default: + } + break + } + rr.resources = nr + return nil +} + +func (rr *RoundRobin) SetIdleTimeout(idleTimeout time.Duration) { + rr.mu.Lock() + defer rr.mu.Unlock() + rr.idleTimeout = idleTimeout +} + +func (rr *RoundRobin) StatsJSON() string { + s, c, a, wc, wt, it := rr.Stats() + return fmt.Sprintf("{\"Size\": %v, \"Capacity\": %v, \"Available\": %v, \"WaitCount\": %v, \"WaitTime\": %v, \"IdleTimeout\": %v}", s, c, a, wc, int64(wt), int64(it)) +} + +func (rr *RoundRobin) Stats() (size, capacity, available, waitCount int64, waitTime, idleTimeout time.Duration) { + rr.mu.Lock() + defer rr.mu.Unlock() + return rr.size, int64(cap(rr.resources)), int64(len(rr.resources)), rr.waitCount, rr.waitTime, rr.idleTimeout +} diff --git a/vendor/github.com/ngaut/pools/roundrobin_test.go b/vendor/github.com/ngaut/pools/roundrobin_test.go new file mode 100644 index 0000000..870503d --- /dev/null +++ b/vendor/github.com/ngaut/pools/roundrobin_test.go @@ -0,0 +1,126 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pools + +import ( + "testing" + "time" +) + +func TestPool(t *testing.T) { + lastId.Set(0) + p := NewRoundRobin(5, time.Duration(10e9)) + p.Open(PoolFactory) + defer p.Close() + + for i := 0; i < 2; i++ { + r, err := p.TryGet() + if err != nil { + t.Errorf("TryGet failed: %v", err) + } + if r.(*TestResource).num != 1 { + t.Errorf("Expecting 1, received %d", r.(*TestResource).num) + } + p.Put(r) + } + // p = [1] + + all := make([]Resource, 5) + for i := 0; i < 5; i++ { + if all[i], _ = p.TryGet(); all[i] == nil { + t.Errorf("TryGet failed with nil") + } + } + // all = [1-5], p is empty + if none, _ := p.TryGet(); none != nil { + t.Errorf("TryGet failed with non-nil") + } + + ch := make(chan bool) + go ResourceWait(p, t, ch) + time.Sleep(1e8) + for i := 0; i < 5; i++ { + p.Put(all[i]) + } + // p = [1-5] + <-ch + // p = [1-5] + if p.waitCount != 1 { + t.Errorf("Expecting 1, received %d", p.waitCount) + } + + for i := 0; i < 5; i++ { + all[i], _ = p.Get() + } + // all = [1-5], p is empty + all[0].(*TestResource).Close() + p.Put(nil) + for i := 1; i < 5; i++ { + p.Put(all[i]) + } + // p = [2-5] + + for i := 0; i < 4; i++ { + r, _ := p.Get() + if r.(*TestResource).num != int64(i+2) { + t.Errorf("Expecting %d, received %d", i+2, r.(*TestResource).num) + } + p.Put(r) + } + + p.SetCapacity(3) + // p = [2-4] + if p.size != 3 { + t.Errorf("Expecting 3, received %d", p.size) + } + + p.SetIdleTimeout(time.Duration(1e8)) + time.Sleep(2e8) + r, _ := p.Get() + if r.(*TestResource).num != 6 { + t.Errorf("Expecting 6, received %d", r.(*TestResource).num) + } + p.Put(r) + // p = [6] +} + +func TestPoolFail(t *testing.T) { + p := NewRoundRobin(5, time.Duration(10e9)) + p.Open(FailFactory) + defer p.Close() + if _, err := p.Get(); err.Error() != "Failed" { + t.Errorf("Expecting Failed, received %v", err) + } +} + +func TestPoolFullFail(t *testing.T) { + p := NewRoundRobin(2, time.Duration(10e9)) + p.Open(SlowFailFactory) + defer p.Close() + ch := make(chan bool) + // The third get should not wait indefinitely + for i := 0; i < 3; i++ { + go func() { + p.Get() + ch <- true + }() + } + for i := 0; i < 3; i++ { + <-ch + } +} + +func ResourceWait(p *RoundRobin, t *testing.T, ch chan bool) { + for i := 0; i < 5; i++ { + if r, err := p.Get(); err != nil { + t.Errorf("TryGet failed: %v", err) + } else if r.(*TestResource).num != int64(i+1) { + t.Errorf("Expecting %d, received %d", i+1, r.(*TestResource).num) + } else { + p.Put(r) + } + } + ch <- true +} diff --git a/vendor/github.com/ngaut/pools/vitess_license b/vendor/github.com/ngaut/pools/vitess_license new file mode 100644 index 0000000..989d02e --- /dev/null +++ b/vendor/github.com/ngaut/pools/vitess_license @@ -0,0 +1,28 @@ +Copyright 2012, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/ngaut/sync2/atomic.go b/vendor/github.com/ngaut/sync2/atomic.go new file mode 100644 index 0000000..909f3b1 --- /dev/null +++ b/vendor/github.com/ngaut/sync2/atomic.go @@ -0,0 +1,114 @@ +// Copyright 2013, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "sync" + "sync/atomic" + "time" +) + +type AtomicInt32 int32 + +func (i *AtomicInt32) Add(n int32) int32 { + return atomic.AddInt32((*int32)(i), n) +} + +func (i *AtomicInt32) Set(n int32) { + atomic.StoreInt32((*int32)(i), n) +} + +func (i *AtomicInt32) Get() int32 { + return atomic.LoadInt32((*int32)(i)) +} + +func (i *AtomicInt32) CompareAndSwap(oldval, newval int32) (swapped bool) { + return atomic.CompareAndSwapInt32((*int32)(i), oldval, newval) +} + +type AtomicUint32 uint32 + +func (i *AtomicUint32) Add(n uint32) uint32 { + return atomic.AddUint32((*uint32)(i), n) +} + +func (i *AtomicUint32) Set(n uint32) { + atomic.StoreUint32((*uint32)(i), n) +} + +func (i *AtomicUint32) Get() uint32 { + return atomic.LoadUint32((*uint32)(i)) +} + +func (i *AtomicUint32) CompareAndSwap(oldval, newval uint32) (swapped bool) { + return atomic.CompareAndSwapUint32((*uint32)(i), oldval, newval) +} + +type AtomicInt64 int64 + +func (i *AtomicInt64) Add(n int64) int64 { + return atomic.AddInt64((*int64)(i), n) +} + +func (i *AtomicInt64) Set(n int64) { + atomic.StoreInt64((*int64)(i), n) +} + +func (i *AtomicInt64) Get() int64 { + return atomic.LoadInt64((*int64)(i)) +} + +func (i *AtomicInt64) CompareAndSwap(oldval, newval int64) (swapped bool) { + return atomic.CompareAndSwapInt64((*int64)(i), oldval, newval) +} + +type AtomicDuration int64 + +func (d *AtomicDuration) Add(duration time.Duration) time.Duration { + return time.Duration(atomic.AddInt64((*int64)(d), int64(duration))) +} + +func (d *AtomicDuration) Set(duration time.Duration) { + atomic.StoreInt64((*int64)(d), int64(duration)) +} + +func (d *AtomicDuration) Get() time.Duration { + return time.Duration(atomic.LoadInt64((*int64)(d))) +} + +func (d *AtomicDuration) CompareAndSwap(oldval, newval time.Duration) (swapped bool) { + return atomic.CompareAndSwapInt64((*int64)(d), int64(oldval), int64(newval)) +} + +// AtomicString gives you atomic-style APIs for string, but +// it's only a convenience wrapper that uses a mutex. So, it's +// not as efficient as the rest of the atomic types. +type AtomicString struct { + mu sync.Mutex + str string +} + +func (s *AtomicString) Set(str string) { + s.mu.Lock() + s.str = str + s.mu.Unlock() +} + +func (s *AtomicString) Get() string { + s.mu.Lock() + str := s.str + s.mu.Unlock() + return str +} + +func (s *AtomicString) CompareAndSwap(oldval, newval string) (swqpped bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.str == oldval { + s.str = newval + return true + } + return false +} diff --git a/vendor/github.com/ngaut/sync2/atomic_test.go b/vendor/github.com/ngaut/sync2/atomic_test.go new file mode 100644 index 0000000..7261159 --- /dev/null +++ b/vendor/github.com/ngaut/sync2/atomic_test.go @@ -0,0 +1,32 @@ +// Copyright 2013, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "testing" +) + +func TestAtomicString(t *testing.T) { + var s AtomicString + if s.Get() != "" { + t.Errorf("want empty, got %s", s.Get()) + } + s.Set("a") + if s.Get() != "a" { + t.Errorf("want a, got %s", s.Get()) + } + if s.CompareAndSwap("b", "c") { + t.Errorf("want false, got true") + } + if s.Get() != "a" { + t.Errorf("want a, got %s", s.Get()) + } + if !s.CompareAndSwap("a", "c") { + t.Errorf("want true, got false") + } + if s.Get() != "c" { + t.Errorf("want c, got %s", s.Get()) + } +} diff --git a/vendor/github.com/ngaut/sync2/cond.go b/vendor/github.com/ngaut/sync2/cond.go new file mode 100644 index 0000000..dea11ae --- /dev/null +++ b/vendor/github.com/ngaut/sync2/cond.go @@ -0,0 +1,56 @@ +// Copyright 2013, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "sync" +) + +// Cond is an alternate implementation of sync.Cond +type Cond struct { + L sync.Locker + sema chan struct{} + waiters AtomicInt64 +} + +func NewCond(l sync.Locker) *Cond { + return &Cond{L: l, sema: make(chan struct{})} +} + +func (c *Cond) Wait() { + c.waiters.Add(1) + c.L.Unlock() + <-c.sema + c.L.Lock() +} + +func (c *Cond) Signal() { + for { + w := c.waiters.Get() + if w == 0 { + return + } + if c.waiters.CompareAndSwap(w, w-1) { + break + } + } + c.sema <- struct{}{} +} + +func (c *Cond) Broadcast() { + var w int64 + for { + w = c.waiters.Get() + if w == 0 { + return + } + if c.waiters.CompareAndSwap(w, 0) { + break + } + } + for i := int64(0); i < w; i++ { + c.sema <- struct{}{} + } +} diff --git a/vendor/github.com/ngaut/sync2/cond_test.go b/vendor/github.com/ngaut/sync2/cond_test.go new file mode 100644 index 0000000..4dd3929 --- /dev/null +++ b/vendor/github.com/ngaut/sync2/cond_test.go @@ -0,0 +1,276 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package sync2 + +import ( + "fmt" + "runtime" + "sync" + "testing" +) + +func TestCondSignal(t *testing.T) { + var m sync.Mutex + c := NewCond(&m) + n := 2 + running := make(chan bool, n) + awake := make(chan bool, n) + for i := 0; i < n; i++ { + go func() { + m.Lock() + running <- true + c.Wait() + awake <- true + m.Unlock() + }() + } + for i := 0; i < n; i++ { + <-running // Wait for everyone to run. + } + for n > 0 { + select { + case <-awake: + t.Fatal("goroutine not asleep") + default: + } + m.Lock() + c.Signal() + m.Unlock() + <-awake // Will deadlock if no goroutine wakes up + select { + case <-awake: + t.Fatal("too many goroutines awake") + default: + } + n-- + } + c.Signal() +} + +func TestCondSignalGenerations(t *testing.T) { + var m sync.Mutex + c := NewCond(&m) + n := 100 + running := make(chan bool, n) + awake := make(chan int, n) + for i := 0; i < n; i++ { + go func(i int) { + m.Lock() + running <- true + c.Wait() + awake <- i + m.Unlock() + }(i) + if i > 0 { + a := <-awake + if a != i-1 { + t.Fatalf("wrong goroutine woke up: want %d, got %d", i-1, a) + } + } + <-running + m.Lock() + c.Signal() + m.Unlock() + } +} + +func TestCondBroadcast(t *testing.T) { + var m sync.Mutex + c := NewCond(&m) + n := 200 + running := make(chan int, n) + awake := make(chan int, n) + exit := false + for i := 0; i < n; i++ { + go func(g int) { + m.Lock() + for !exit { + running <- g + c.Wait() + awake <- g + } + m.Unlock() + }(i) + } + for i := 0; i < n; i++ { + for i := 0; i < n; i++ { + <-running // Will deadlock unless n are running. + } + if i == n-1 { + m.Lock() + exit = true + m.Unlock() + } + select { + case <-awake: + t.Fatal("goroutine not asleep") + default: + } + m.Lock() + c.Broadcast() + m.Unlock() + seen := make([]bool, n) + for i := 0; i < n; i++ { + g := <-awake + if seen[g] { + t.Fatal("goroutine woke up twice") + } + seen[g] = true + } + } + select { + case <-running: + t.Fatal("goroutine did not exit") + default: + } + c.Broadcast() +} + +func TestRace(t *testing.T) { + x := 0 + c := NewCond(&sync.Mutex{}) + done := make(chan bool) + go func() { + c.L.Lock() + x = 1 + c.Wait() + if x != 2 { + t.Fatal("want 2") + } + x = 3 + c.Signal() + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 1 { + x = 2 + c.Signal() + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 2 { + c.Wait() + if x != 3 { + t.Fatal("want 3") + } + break + } + if x == 3 { + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + <-done + <-done + <-done +} + +// Bench: Rename this function to TestBench for running benchmarks +func Bench(t *testing.T) { + waitvals := []int{1, 2, 4, 8} + maxprocs := []int{1, 2, 4} + fmt.Printf("procs\twaiters\told\tnew\tdelta\n") + for _, procs := range maxprocs { + runtime.GOMAXPROCS(procs) + for _, waiters := range waitvals { + oldbench := func(b *testing.B) { + benchmarkCond(b, waiters) + } + oldbr := testing.Benchmark(oldbench) + newbench := func(b *testing.B) { + benchmarkCond2(b, waiters) + } + newbr := testing.Benchmark(newbench) + oldns := oldbr.NsPerOp() + newns := newbr.NsPerOp() + percent := float64(newns-oldns) * 100.0 / float64(oldns) + fmt.Printf("%d\t%d\t%d\t%d\t%6.2f%%\n", procs, waiters, oldns, newns, percent) + } + } +} + +func benchmarkCond2(b *testing.B, waiters int) { + c := NewCond(&sync.Mutex{}) + done := make(chan bool) + id := 0 + + for routine := 0; routine < waiters+1; routine++ { + go func() { + for i := 0; i < b.N; i++ { + c.L.Lock() + if id == -1 { + c.L.Unlock() + break + } + id++ + if id == waiters+1 { + id = 0 + c.Broadcast() + } else { + c.Wait() + } + c.L.Unlock() + } + c.L.Lock() + id = -1 + c.Broadcast() + c.L.Unlock() + done <- true + }() + } + for routine := 0; routine < waiters+1; routine++ { + <-done + } +} + +func benchmarkCond(b *testing.B, waiters int) { + c := sync.NewCond(&sync.Mutex{}) + done := make(chan bool) + id := 0 + + for routine := 0; routine < waiters+1; routine++ { + go func() { + for i := 0; i < b.N; i++ { + c.L.Lock() + if id == -1 { + c.L.Unlock() + break + } + id++ + if id == waiters+1 { + id = 0 + c.Broadcast() + } else { + c.Wait() + } + c.L.Unlock() + } + c.L.Lock() + id = -1 + c.Broadcast() + c.L.Unlock() + done <- true + }() + } + for routine := 0; routine < waiters+1; routine++ { + <-done + } +} diff --git a/vendor/github.com/ngaut/sync2/semaphore.go b/vendor/github.com/ngaut/sync2/semaphore.go new file mode 100644 index 0000000..190a27d --- /dev/null +++ b/vendor/github.com/ngaut/sync2/semaphore.go @@ -0,0 +1,55 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +// What's in a name? Channels have all you need to emulate a counting +// semaphore with a boatload of extra functionality. However, in some +// cases, you just want a familiar API. + +import ( + "time" +) + +// Semaphore is a counting semaphore with the option to +// specify a timeout. +type Semaphore struct { + slots chan struct{} + timeout time.Duration +} + +// NewSemaphore creates a Semaphore. The count parameter must be a positive +// number. A timeout of zero means that there is no timeout. +func NewSemaphore(count int, timeout time.Duration) *Semaphore { + sem := &Semaphore{ + slots: make(chan struct{}, count), + timeout: timeout, + } + for i := 0; i < count; i++ { + sem.slots <- struct{}{} + } + return sem +} + +// Acquire returns true on successful acquisition, and +// false on a timeout. +func (sem *Semaphore) Acquire() bool { + if sem.timeout == 0 { + <-sem.slots + return true + } + select { + case <-sem.slots: + return true + case <-time.After(sem.timeout): + return false + } +} + +// Release releases the acquired semaphore. You must +// not release more than the number of semaphores you've +// acquired. +func (sem *Semaphore) Release() { + sem.slots <- struct{}{} +} diff --git a/vendor/github.com/ngaut/sync2/semaphore_test.go b/vendor/github.com/ngaut/sync2/semaphore_test.go new file mode 100644 index 0000000..207c8f9 --- /dev/null +++ b/vendor/github.com/ngaut/sync2/semaphore_test.go @@ -0,0 +1,41 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "testing" + "time" +) + +func TestSemaNoTimeout(t *testing.T) { + s := NewSemaphore(1, 0) + s.Acquire() + released := false + go func() { + time.Sleep(10 * time.Millisecond) + released = true + s.Release() + }() + s.Acquire() + if !released { + t.Errorf("want true, got false") + } +} + +func TestSemaTimeout(t *testing.T) { + s := NewSemaphore(1, 5*time.Millisecond) + s.Acquire() + go func() { + time.Sleep(10 * time.Millisecond) + s.Release() + }() + if ok := s.Acquire(); ok { + t.Errorf("want false, got true") + } + time.Sleep(10 * time.Millisecond) + if ok := s.Acquire(); !ok { + t.Errorf("want true, got false") + } +} diff --git a/vendor/github.com/ngaut/sync2/service_manager.go b/vendor/github.com/ngaut/sync2/service_manager.go new file mode 100644 index 0000000..4b85e01 --- /dev/null +++ b/vendor/github.com/ngaut/sync2/service_manager.go @@ -0,0 +1,121 @@ +// Copyright 2013, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "sync" +) + +// These are the three predefined states of a service. +const ( + SERVICE_STOPPED = iota + SERVICE_RUNNING + SERVICE_SHUTTING_DOWN +) + +var stateNames = []string{ + "Stopped", + "Running", + "ShuttingDown", +} + +// ServiceManager manages the state of a service through its lifecycle. +type ServiceManager struct { + mu sync.Mutex + wg sync.WaitGroup + err error // err is the error returned from the service function. + state AtomicInt64 + // shutdown is created when the service starts and is closed when the service + // enters the SERVICE_SHUTTING_DOWN state. + shutdown chan struct{} +} + +// Go tries to change the state from SERVICE_STOPPED to SERVICE_RUNNING. +// +// If the current state is not SERVICE_STOPPED (already running), it returns +// false immediately. +// +// On successful transition, it launches the service as a goroutine and returns +// true. The service function is responsible for returning on its own when +// requested, either by regularly checking svc.IsRunning(), or by waiting for +// the svc.ShuttingDown channel to be closed. +// +// When the service func returns, the state is reverted to SERVICE_STOPPED. +func (svm *ServiceManager) Go(service func(svc *ServiceContext) error) bool { + svm.mu.Lock() + defer svm.mu.Unlock() + if !svm.state.CompareAndSwap(SERVICE_STOPPED, SERVICE_RUNNING) { + return false + } + svm.wg.Add(1) + svm.err = nil + svm.shutdown = make(chan struct{}) + go func() { + svm.err = service(&ServiceContext{ShuttingDown: svm.shutdown}) + svm.state.Set(SERVICE_STOPPED) + svm.wg.Done() + }() + return true +} + +// Stop tries to change the state from SERVICE_RUNNING to SERVICE_SHUTTING_DOWN. +// If the current state is not SERVICE_RUNNING, it returns false immediately. +// On successul transition, it waits for the service to finish, and returns true. +// You are allowed to Go() again after a Stop(). +func (svm *ServiceManager) Stop() bool { + svm.mu.Lock() + defer svm.mu.Unlock() + if !svm.state.CompareAndSwap(SERVICE_RUNNING, SERVICE_SHUTTING_DOWN) { + return false + } + // Signal the service that we've transitioned to SERVICE_SHUTTING_DOWN. + close(svm.shutdown) + svm.shutdown = nil + svm.wg.Wait() + return true +} + +// Wait waits for the service to terminate if it's currently running. +func (svm *ServiceManager) Wait() { + svm.wg.Wait() +} + +// Join waits for the service to terminate and returns the value returned by the +// service function. +func (svm *ServiceManager) Join() error { + svm.wg.Wait() + return svm.err +} + +// State returns the current state of the service. +// This should only be used to report the current state. +func (svm *ServiceManager) State() int64 { + return svm.state.Get() +} + +// StateName returns the name of the current state. +func (svm *ServiceManager) StateName() string { + return stateNames[svm.State()] +} + +// ServiceContext is passed into the service function to give it access to +// information about the running service. +type ServiceContext struct { + // ShuttingDown is a channel that the service can select on to be notified + // when it should shut down. The channel is closed when the state transitions + // from SERVICE_RUNNING to SERVICE_SHUTTING_DOWN. + ShuttingDown chan struct{} +} + +// IsRunning returns true if the ServiceContext.ShuttingDown channel has not +// been closed yet. +func (svc *ServiceContext) IsRunning() bool { + select { + case <-svc.ShuttingDown: + return false + default: + return true + } +} diff --git a/vendor/github.com/ngaut/sync2/service_manager_test.go b/vendor/github.com/ngaut/sync2/service_manager_test.go new file mode 100644 index 0000000..a41912e --- /dev/null +++ b/vendor/github.com/ngaut/sync2/service_manager_test.go @@ -0,0 +1,176 @@ +// Copyright 2013, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "fmt" + "testing" + "time" +) + +type testService struct { + activated AtomicInt64 + t *testing.T +} + +func (ts *testService) service(svc *ServiceContext) error { + if !ts.activated.CompareAndSwap(0, 1) { + ts.t.Fatalf("service called more than once") + } + for svc.IsRunning() { + time.Sleep(10 * time.Millisecond) + + } + if !ts.activated.CompareAndSwap(1, 0) { + ts.t.Fatalf("service ended more than once") + } + return nil +} + +func (ts *testService) selectService(svc *ServiceContext) error { + if !ts.activated.CompareAndSwap(0, 1) { + ts.t.Fatalf("service called more than once") + } +serviceLoop: + for svc.IsRunning() { + select { + case <-time.After(1 * time.Second): + ts.t.Errorf("service didn't stop when shutdown channel was closed") + case <-svc.ShuttingDown: + break serviceLoop + } + } + if !ts.activated.CompareAndSwap(1, 0) { + ts.t.Fatalf("service ended more than once") + } + return nil +} + +func TestServiceManager(t *testing.T) { + ts := &testService{t: t} + var sm ServiceManager + if sm.StateName() != "Stopped" { + t.Errorf("want Stopped, got %s", sm.StateName()) + } + result := sm.Go(ts.service) + if !result { + t.Errorf("want true, got false") + } + if sm.StateName() != "Running" { + t.Errorf("want Running, got %s", sm.StateName()) + } + time.Sleep(5 * time.Millisecond) + if val := ts.activated.Get(); val != 1 { + t.Errorf("want 1, got %d", val) + } + result = sm.Go(ts.service) + if result { + t.Errorf("want false, got true") + } + result = sm.Stop() + if !result { + t.Errorf("want true, got false") + } + if val := ts.activated.Get(); val != 0 { + t.Errorf("want 0, got %d", val) + } + result = sm.Stop() + if result { + t.Errorf("want false, got true") + } + sm.state.Set(SERVICE_SHUTTING_DOWN) + if sm.StateName() != "ShuttingDown" { + t.Errorf("want ShuttingDown, got %s", sm.StateName()) + } +} + +func TestServiceManagerSelect(t *testing.T) { + ts := &testService{t: t} + var sm ServiceManager + if sm.StateName() != "Stopped" { + t.Errorf("want Stopped, got %s", sm.StateName()) + } + result := sm.Go(ts.selectService) + if !result { + t.Errorf("want true, got false") + } + if sm.StateName() != "Running" { + t.Errorf("want Running, got %s", sm.StateName()) + } + time.Sleep(5 * time.Millisecond) + if val := ts.activated.Get(); val != 1 { + t.Errorf("want 1, got %d", val) + } + result = sm.Go(ts.service) + if result { + t.Errorf("want false, got true") + } + result = sm.Stop() + if !result { + t.Errorf("want true, got false") + } + if val := ts.activated.Get(); val != 0 { + t.Errorf("want 0, got %d", val) + } + result = sm.Stop() + if result { + t.Errorf("want false, got true") + } + sm.state.Set(SERVICE_SHUTTING_DOWN) + if sm.StateName() != "ShuttingDown" { + t.Errorf("want ShuttingDown, got %s", sm.StateName()) + } +} + +func TestServiceManagerWaitNotRunning(t *testing.T) { + done := make(chan struct{}) + var sm ServiceManager + go func() { + sm.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(1 * time.Second): + t.Errorf("Wait() blocked even though service wasn't running.") + } +} + +func TestServiceManagerWait(t *testing.T) { + done := make(chan struct{}) + stop := make(chan struct{}) + var sm ServiceManager + sm.Go(func(*ServiceContext) error { + <-stop + return nil + }) + go func() { + sm.Wait() + close(done) + }() + time.Sleep(100 * time.Millisecond) + select { + case <-done: + t.Errorf("Wait() didn't block while service was still running.") + default: + } + close(stop) + select { + case <-done: + case <-time.After(100 * time.Millisecond): + t.Errorf("Wait() didn't unblock when service stopped.") + } +} + +func TestServiceManagerJoin(t *testing.T) { + want := "error 123" + var sm ServiceManager + sm.Go(func(*ServiceContext) error { + return fmt.Errorf("error 123") + }) + if got := sm.Join().Error(); got != want { + t.Errorf("Join().Error() = %#v, want %#v", got, want) + } +} diff --git a/vendor/github.com/ngaut/sync2/vitess_license b/vendor/github.com/ngaut/sync2/vitess_license new file mode 100644 index 0000000..989d02e --- /dev/null +++ b/vendor/github.com/ngaut/sync2/vitess_license @@ -0,0 +1,28 @@ +Copyright 2012, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/ngaut/systimemon/LICENSE b/vendor/github.com/ngaut/systimemon/LICENSE new file mode 100644 index 0000000..8dada3e --- /dev/null +++ b/vendor/github.com/ngaut/systimemon/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/ngaut/systimemon/README.md b/vendor/github.com/ngaut/systimemon/README.md new file mode 100644 index 0000000..39e3630 --- /dev/null +++ b/vendor/github.com/ngaut/systimemon/README.md @@ -0,0 +1,2 @@ +# systimemon +System time monitor diff --git a/vendor/github.com/ngaut/systimemon/systime_mon.go b/vendor/github.com/ngaut/systimemon/systime_mon.go new file mode 100644 index 0000000..0d3defe --- /dev/null +++ b/vendor/github.com/ngaut/systimemon/systime_mon.go @@ -0,0 +1,24 @@ +package systimemon + +import ( + "time" + + "github.com/ngaut/log" +) + +// StartMonitor will call systimeErrHandler if system time jump backward. +func StartMonitor(now func() time.Time, systimeErrHandler func()) { + log.Info("start system time monitor") + tick := time.NewTicker(100 * time.Millisecond) + defer tick.Stop() + for { + last := now() + select { + case <-tick.C: + if now().Sub(last) < 0 { + log.Errorf("system time jump backward, last:%v", last) + systimeErrHandler() + } + } + } +} diff --git a/vendor/github.com/ngaut/systimemon/systime_mon_test.go b/vendor/github.com/ngaut/systimemon/systime_mon_test.go new file mode 100644 index 0000000..3f6dc1b --- /dev/null +++ b/vendor/github.com/ngaut/systimemon/systime_mon_test.go @@ -0,0 +1,28 @@ +package systimemon + +import ( + "testing" + "time" +) + +func TestSystimeMonitor(t *testing.T) { + jumpForward := false + trigged := false + go StartMonitor( + func() time.Time { + if !trigged { + trigged = true + return time.Now() + } + + return time.Now().Add(-2 * time.Second) + }, func() { + jumpForward = true + }) + + time.Sleep(1 * time.Second) + + if !jumpForward { + t.Error("should detect time error") + } +} From 53f4290cbc6ce39df5069cbc94ee8cd64e2d6216 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 12 Feb 2017 13:19:46 +0200 Subject: [PATCH 097/104] adding net/context for dependency --- vendor/golang.org/x/net/context/context.go | 156 +++++ .../golang.org/x/net/context/context_test.go | 583 ++++++++++++++++++ .../x/net/context/ctxhttp/ctxhttp.go | 74 +++ .../x/net/context/ctxhttp/ctxhttp_17_test.go | 28 + .../x/net/context/ctxhttp/ctxhttp_pre17.go | 147 +++++ .../net/context/ctxhttp/ctxhttp_pre17_test.go | 79 +++ .../x/net/context/ctxhttp/ctxhttp_test.go | 105 ++++ vendor/golang.org/x/net/context/go17.go | 72 +++ vendor/golang.org/x/net/context/pre_go17.go | 300 +++++++++ .../x/net/context/withtimeout_test.go | 26 + 10 files changed, 1570 insertions(+) create mode 100644 vendor/golang.org/x/net/context/context.go create mode 100644 vendor/golang.org/x/net/context/context_test.go create mode 100644 vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go create mode 100644 vendor/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go create mode 100644 vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go create mode 100644 vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go create mode 100644 vendor/golang.org/x/net/context/ctxhttp/ctxhttp_test.go create mode 100644 vendor/golang.org/x/net/context/go17.go create mode 100644 vendor/golang.org/x/net/context/pre_go17.go create mode 100644 vendor/golang.org/x/net/context/withtimeout_test.go diff --git a/vendor/golang.org/x/net/context/context.go b/vendor/golang.org/x/net/context/context.go new file mode 100644 index 0000000..f143ed6 --- /dev/null +++ b/vendor/golang.org/x/net/context/context.go @@ -0,0 +1,156 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context defines the Context type, which carries deadlines, +// cancelation signals, and other request-scoped values across API boundaries +// and between processes. +// +// Incoming requests to a server should create a Context, and outgoing calls to +// servers should accept a Context. The chain of function calls between must +// propagate the Context, optionally replacing it with a modified copy created +// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// +// Programs that use Contexts should follow these rules to keep interfaces +// consistent across packages and enable static analysis tools to check context +// propagation: +// +// Do not store Contexts inside a struct type; instead, pass a Context +// explicitly to each function that needs it. The Context should be the first +// parameter, typically named ctx: +// +// func DoSomething(ctx context.Context, arg Arg) error { +// // ... use ctx ... +// } +// +// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// if you are unsure about which Context to use. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The same Context may be passed to functions running in different goroutines; +// Contexts are safe for simultaneous use by multiple goroutines. +// +// See http://blog.golang.org/context for example code for a server that uses +// Contexts. +package context // import "golang.org/x/net/context" + +import "time" + +// A Context carries a deadline, a cancelation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context interface { + // Deadline returns the time when work done on behalf of this context + // should be canceled. Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // + // WithCancel arranges for Done to be closed when cancel is called; + // WithDeadline arranges for Done to be closed when the deadline + // expires; WithTimeout arranges for Done to be closed when the timeout + // elapses. + // + // Done is provided for use in select statements: + // + // // Stream generates values with DoSomething and sends them to out + // // until DoSomething returns an error or ctx.Done is closed. + // func Stream(ctx context.Context, out chan<- Value) error { + // for { + // v, err := DoSomething(ctx) + // if err != nil { + // return err + // } + // select { + // case <-ctx.Done(): + // return ctx.Err() + // case out <- v: + // } + // } + // } + // + // See http://blog.golang.org/pipelines for more examples of how to use + // a Done channel for cancelation. + Done() <-chan struct{} + + // Err returns a non-nil error value after Done is closed. Err returns + // Canceled if the context was canceled or DeadlineExceeded if the + // context's deadline passed. No other values for Err are defined. + // After Done is closed, successive calls to Err return the same value. + Err() error + + // Value returns the value associated with this context for key, or nil + // if no value is associated with key. Successive calls to Value with + // the same key returns the same result. + // + // Use context values only for request-scoped data that transits + // processes and API boundaries, not for passing optional parameters to + // functions. + // + // A key identifies a specific value in a Context. Functions that wish + // to store values in Context typically allocate a key in a global + // variable then use that key as the argument to context.WithValue and + // Context.Value. A key can be any type that supports equality; + // packages should define keys as an unexported type to avoid + // collisions. + // + // Packages that define a Context key should provide type-safe accessors + // for the values stores using that key: + // + // // Package user defines a User type that's stored in Contexts. + // package user + // + // import "golang.org/x/net/context" + // + // // User is the type of value stored in the Contexts. + // type User struct {...} + // + // // key is an unexported type for keys defined in this package. + // // This prevents collisions with keys defined in other packages. + // type key int + // + // // userKey is the key for user.User values in Contexts. It is + // // unexported; clients use user.NewContext and user.FromContext + // // instead of using this key directly. + // var userKey key = 0 + // + // // NewContext returns a new Context that carries value u. + // func NewContext(ctx context.Context, u *User) context.Context { + // return context.WithValue(ctx, userKey, u) + // } + // + // // FromContext returns the User value stored in ctx, if any. + // func FromContext(ctx context.Context) (*User, bool) { + // u, ok := ctx.Value(userKey).(*User) + // return u, ok + // } + Value(key interface{}) interface{} +} + +// Background returns a non-nil, empty Context. It is never canceled, has no +// values, and has no deadline. It is typically used by the main function, +// initialization, and tests, and as the top-level Context for incoming +// requests. +func Background() Context { + return background +} + +// TODO returns a non-nil, empty Context. Code should use context.TODO when +// it's unclear which Context to use or it is not yet available (because the +// surrounding function has not yet been extended to accept a Context +// parameter). TODO is recognized by static analysis tools that determine +// whether Contexts are propagated correctly in a program. +func TODO() Context { + return todo +} + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc func() diff --git a/vendor/golang.org/x/net/context/context_test.go b/vendor/golang.org/x/net/context/context_test.go new file mode 100644 index 0000000..6284413 --- /dev/null +++ b/vendor/golang.org/x/net/context/context_test.go @@ -0,0 +1,583 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.7 + +package context + +import ( + "fmt" + "math/rand" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +// otherContext is a Context that's not one of the types defined in context.go. +// This lets us test code paths that differ based on the underlying type of the +// Context. +type otherContext struct { + Context +} + +func TestBackground(t *testing.T) { + c := Background() + if c == nil { + t.Fatalf("Background returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.Background"; got != want { + t.Errorf("Background().String() = %q want %q", got, want) + } +} + +func TestTODO(t *testing.T) { + c := TODO() + if c == nil { + t.Fatalf("TODO returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.TODO"; got != want { + t.Errorf("TODO().String() = %q want %q", got, want) + } +} + +func TestWithCancel(t *testing.T) { + c1, cancel := WithCancel(Background()) + + if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { + t.Errorf("c1.String() = %q want %q", got, want) + } + + o := otherContext{c1} + c2, _ := WithCancel(o) + contexts := []Context{c1, o, c2} + + for i, c := range contexts { + if d := c.Done(); d == nil { + t.Errorf("c[%d].Done() == %v want non-nil", i, d) + } + if e := c.Err(); e != nil { + t.Errorf("c[%d].Err() == %v want nil", i, e) + } + + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + } + + cancel() + time.Sleep(100 * time.Millisecond) // let cancelation propagate + + for i, c := range contexts { + select { + case <-c.Done(): + default: + t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) + } + if e := c.Err(); e != Canceled { + t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) + } + } +} + +func TestParentFinishesChild(t *testing.T) { + // Context tree: + // parent -> cancelChild + // parent -> valueChild -> timerChild + parent, cancel := WithCancel(Background()) + cancelChild, stop := WithCancel(parent) + defer stop() + valueChild := WithValue(parent, "key", "value") + timerChild, stop := WithTimeout(valueChild, 10000*time.Hour) + defer stop() + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-cancelChild.Done(): + t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) + case x := <-timerChild.Done(): + t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) + case x := <-valueChild.Done(): + t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) + default: + } + + // The parent's children should contain the two cancelable children. + pc := parent.(*cancelCtx) + cc := cancelChild.(*cancelCtx) + tc := timerChild.(*timerCtx) + pc.mu.Lock() + if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { + t.Errorf("bad linkage: pc.children = %v, want %v and %v", + pc.children, cc, tc) + } + pc.mu.Unlock() + + if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) + } + if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) + } + + cancel() + + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) + } + pc.mu.Unlock() + + // parent and children should all be finished. + check := func(ctx Context, name string) { + select { + case <-ctx.Done(): + default: + t.Errorf("<-%s.Done() blocked, but shouldn't have", name) + } + if e := ctx.Err(); e != Canceled { + t.Errorf("%s.Err() == %v want %v", name, e, Canceled) + } + } + check(parent, "parent") + check(cancelChild, "cancelChild") + check(valueChild, "valueChild") + check(timerChild, "timerChild") + + // WithCancel should return a canceled context on a canceled parent. + precanceledChild := WithValue(parent, "key", "value") + select { + case <-precanceledChild.Done(): + default: + t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") + } + if e := precanceledChild.Err(); e != Canceled { + t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) + } +} + +func TestChildFinishesFirst(t *testing.T) { + cancelable, stop := WithCancel(Background()) + defer stop() + for _, parent := range []Context{Background(), cancelable} { + child, cancel := WithCancel(parent) + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-child.Done(): + t.Errorf("<-child.Done() == %v want nothing (it should block)", x) + default: + } + + cc := child.(*cancelCtx) + pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background() + if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { + t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) + } + + if pcok { + pc.mu.Lock() + if len(pc.children) != 1 || !pc.children[cc] { + t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) + } + pc.mu.Unlock() + } + + cancel() + + if pcok { + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) + } + pc.mu.Unlock() + } + + // child should be finished. + select { + case <-child.Done(): + default: + t.Errorf("<-child.Done() blocked, but shouldn't have") + } + if e := child.Err(); e != Canceled { + t.Errorf("child.Err() == %v want %v", e, Canceled) + } + + // parent should not be finished. + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + default: + } + if e := parent.Err(); e != nil { + t.Errorf("parent.Err() == %v want nil", e) + } + } +} + +func testDeadline(c Context, wait time.Duration, t *testing.T) { + select { + case <-time.After(wait): + t.Fatalf("context should have timed out") + case <-c.Done(): + } + if e := c.Err(); e != DeadlineExceeded { + t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) + } +} + +func TestDeadline(t *testing.T) { + t.Parallel() + const timeUnit = 500 * time.Millisecond + c, _ := WithDeadline(Background(), time.Now().Add(1*timeUnit)) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, 2*timeUnit, t) + + c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit)) + o := otherContext{c} + testDeadline(o, 2*timeUnit, t) + + c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit)) + o = otherContext{c} + c, _ = WithDeadline(o, time.Now().Add(3*timeUnit)) + testDeadline(c, 2*timeUnit, t) +} + +func TestTimeout(t *testing.T) { + t.Parallel() + const timeUnit = 500 * time.Millisecond + c, _ := WithTimeout(Background(), 1*timeUnit) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, 2*timeUnit, t) + + c, _ = WithTimeout(Background(), 1*timeUnit) + o := otherContext{c} + testDeadline(o, 2*timeUnit, t) + + c, _ = WithTimeout(Background(), 1*timeUnit) + o = otherContext{c} + c, _ = WithTimeout(o, 3*timeUnit) + testDeadline(c, 2*timeUnit, t) +} + +func TestCanceledTimeout(t *testing.T) { + t.Parallel() + const timeUnit = 500 * time.Millisecond + c, _ := WithTimeout(Background(), 2*timeUnit) + o := otherContext{c} + c, cancel := WithTimeout(o, 4*timeUnit) + cancel() + time.Sleep(1 * timeUnit) // let cancelation propagate + select { + case <-c.Done(): + default: + t.Errorf("<-c.Done() blocked, but shouldn't have") + } + if e := c.Err(); e != Canceled { + t.Errorf("c.Err() == %v want %v", e, Canceled) + } +} + +type key1 int +type key2 int + +var k1 = key1(1) +var k2 = key2(1) // same int as k1, different type +var k3 = key2(3) // same type as k2, different int + +func TestValues(t *testing.T) { + check := func(c Context, nm, v1, v2, v3 string) { + if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { + t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) + } + if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { + t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) + } + if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { + t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) + } + } + + c0 := Background() + check(c0, "c0", "", "", "") + + c1 := WithValue(Background(), k1, "c1k1") + check(c1, "c1", "c1k1", "", "") + + if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want { + t.Errorf("c.String() = %q want %q", got, want) + } + + c2 := WithValue(c1, k2, "c2k2") + check(c2, "c2", "c1k1", "c2k2", "") + + c3 := WithValue(c2, k3, "c3k3") + check(c3, "c2", "c1k1", "c2k2", "c3k3") + + c4 := WithValue(c3, k1, nil) + check(c4, "c4", "", "c2k2", "c3k3") + + o0 := otherContext{Background()} + check(o0, "o0", "", "", "") + + o1 := otherContext{WithValue(Background(), k1, "c1k1")} + check(o1, "o1", "c1k1", "", "") + + o2 := WithValue(o1, k2, "o2k2") + check(o2, "o2", "c1k1", "o2k2", "") + + o3 := otherContext{c4} + check(o3, "o3", "", "c2k2", "c3k3") + + o4 := WithValue(o3, k3, nil) + check(o4, "o4", "", "c2k2", "") +} + +func TestAllocs(t *testing.T) { + bg := Background() + for _, test := range []struct { + desc string + f func() + limit float64 + gccgoLimit float64 + }{ + { + desc: "Background()", + f: func() { Background() }, + limit: 0, + gccgoLimit: 0, + }, + { + desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), + f: func() { + c := WithValue(bg, k1, nil) + c.Value(k1) + }, + limit: 3, + gccgoLimit: 3, + }, + { + desc: "WithTimeout(bg, 15*time.Millisecond)", + f: func() { + c, _ := WithTimeout(bg, 15*time.Millisecond) + <-c.Done() + }, + limit: 8, + gccgoLimit: 16, + }, + { + desc: "WithCancel(bg)", + f: func() { + c, cancel := WithCancel(bg) + cancel() + <-c.Done() + }, + limit: 5, + gccgoLimit: 8, + }, + { + desc: "WithTimeout(bg, 100*time.Millisecond)", + f: func() { + c, cancel := WithTimeout(bg, 100*time.Millisecond) + cancel() + <-c.Done() + }, + limit: 8, + gccgoLimit: 25, + }, + } { + limit := test.limit + if runtime.Compiler == "gccgo" { + // gccgo does not yet do escape analysis. + // TODO(iant): Remove this when gccgo does do escape analysis. + limit = test.gccgoLimit + } + if n := testing.AllocsPerRun(100, test.f); n > limit { + t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) + } + } +} + +func TestSimultaneousCancels(t *testing.T) { + root, cancel := WithCancel(Background()) + m := map[Context]CancelFunc{root: cancel} + q := []Context{root} + // Create a tree of contexts. + for len(q) != 0 && len(m) < 100 { + parent := q[0] + q = q[1:] + for i := 0; i < 4; i++ { + ctx, cancel := WithCancel(parent) + m[ctx] = cancel + q = append(q, ctx) + } + } + // Start all the cancels in a random order. + var wg sync.WaitGroup + wg.Add(len(m)) + for _, cancel := range m { + go func(cancel CancelFunc) { + cancel() + wg.Done() + }(cancel) + } + // Wait on all the contexts in a random order. + for ctx := range m { + select { + case <-ctx.Done(): + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n]) + } + } + // Wait for all the cancel functions to return. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n]) + } +} + +func TestInterlockedCancels(t *testing.T) { + parent, cancelParent := WithCancel(Background()) + child, cancelChild := WithCancel(parent) + go func() { + parent.Done() + cancelChild() + }() + cancelParent() + select { + case <-child.Done(): + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n]) + } +} + +func TestLayersCancel(t *testing.T) { + testLayers(t, time.Now().UnixNano(), false) +} + +func TestLayersTimeout(t *testing.T) { + testLayers(t, time.Now().UnixNano(), true) +} + +func testLayers(t *testing.T, seed int64, testTimeout bool) { + rand.Seed(seed) + errorf := func(format string, a ...interface{}) { + t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) + } + const ( + timeout = 200 * time.Millisecond + minLayers = 30 + ) + type value int + var ( + vals []*value + cancels []CancelFunc + numTimers int + ctx = Background() + ) + for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { + switch rand.Intn(3) { + case 0: + v := new(value) + ctx = WithValue(ctx, v, v) + vals = append(vals, v) + case 1: + var cancel CancelFunc + ctx, cancel = WithCancel(ctx) + cancels = append(cancels, cancel) + case 2: + var cancel CancelFunc + ctx, cancel = WithTimeout(ctx, timeout) + cancels = append(cancels, cancel) + numTimers++ + } + } + checkValues := func(when string) { + for _, key := range vals { + if val := ctx.Value(key).(*value); key != val { + errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) + } + } + } + select { + case <-ctx.Done(): + errorf("ctx should not be canceled yet") + default: + } + if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { + t.Errorf("ctx.String() = %q want prefix %q", s, prefix) + } + t.Log(ctx) + checkValues("before cancel") + if testTimeout { + select { + case <-ctx.Done(): + case <-time.After(timeout + 100*time.Millisecond): + errorf("ctx should have timed out") + } + checkValues("after timeout") + } else { + cancel := cancels[rand.Intn(len(cancels))] + cancel() + select { + case <-ctx.Done(): + default: + errorf("ctx should be canceled") + } + checkValues("after cancel") + } +} + +func TestCancelRemoves(t *testing.T) { + checkChildren := func(when string, ctx Context, want int) { + if got := len(ctx.(*cancelCtx).children); got != want { + t.Errorf("%s: context has %d children, want %d", when, got, want) + } + } + + ctx, _ := WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel := WithCancel(ctx) + checkChildren("with WithCancel child ", ctx, 1) + cancel() + checkChildren("after cancelling WithCancel child", ctx, 0) + + ctx, _ = WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel = WithTimeout(ctx, 60*time.Minute) + checkChildren("with WithTimeout child ", ctx, 1) + cancel() + checkChildren("after cancelling WithTimeout child", ctx, 0) +} diff --git a/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go new file mode 100644 index 0000000..606cf1f --- /dev/null +++ b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go @@ -0,0 +1,74 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +// Package ctxhttp provides helper functions for performing context-aware HTTP requests. +package ctxhttp // import "golang.org/x/net/context/ctxhttp" + +import ( + "io" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/context" +) + +// Do sends an HTTP request with the provided http.Client and returns +// an HTTP response. +// +// If the client is nil, http.DefaultClient is used. +// +// The provided ctx must be non-nil. If it is canceled or times out, +// ctx.Err() will be returned. +func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req.WithContext(ctx)) + // If we got an error, and the context has been canceled, + // the context's error is probably more useful. + if err != nil { + select { + case <-ctx.Done(): + err = ctx.Err() + default: + } + } + return resp, err +} + +// Get issues a GET request via the Do function. +func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Head issues a HEAD request via the Do function. +func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Post issues a POST request via the Do function. +func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return Do(ctx, client, req) +} + +// PostForm issues a POST request via the Do function. +func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { + return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} diff --git a/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go new file mode 100644 index 0000000..9f0f90f --- /dev/null +++ b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go @@ -0,0 +1,28 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9,go1.7 + +package ctxhttp + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "context" +) + +func TestGo17Context(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "ok") + })) + ctx := context.Background() + resp, err := Get(ctx, http.DefaultClient, ts.URL) + if resp == nil || err != nil { + t.Fatalf("error received from client: %v %v", err, resp) + } + resp.Body.Close() +} diff --git a/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go new file mode 100644 index 0000000..926870c --- /dev/null +++ b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go @@ -0,0 +1,147 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.7 + +package ctxhttp // import "golang.org/x/net/context/ctxhttp" + +import ( + "io" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/context" +) + +func nop() {} + +var ( + testHookContextDoneBeforeHeaders = nop + testHookDoReturned = nop + testHookDidBodyClose = nop +) + +// Do sends an HTTP request with the provided http.Client and returns an HTTP response. +// If the client is nil, http.DefaultClient is used. +// If the context is canceled or times out, ctx.Err() will be returned. +func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { + if client == nil { + client = http.DefaultClient + } + + // TODO(djd): Respect any existing value of req.Cancel. + cancel := make(chan struct{}) + req.Cancel = cancel + + type responseAndError struct { + resp *http.Response + err error + } + result := make(chan responseAndError, 1) + + // Make local copies of test hooks closed over by goroutines below. + // Prevents data races in tests. + testHookDoReturned := testHookDoReturned + testHookDidBodyClose := testHookDidBodyClose + + go func() { + resp, err := client.Do(req) + testHookDoReturned() + result <- responseAndError{resp, err} + }() + + var resp *http.Response + + select { + case <-ctx.Done(): + testHookContextDoneBeforeHeaders() + close(cancel) + // Clean up after the goroutine calling client.Do: + go func() { + if r := <-result; r.resp != nil { + testHookDidBodyClose() + r.resp.Body.Close() + } + }() + return nil, ctx.Err() + case r := <-result: + var err error + resp, err = r.resp, r.err + if err != nil { + return resp, err + } + } + + c := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + close(cancel) + case <-c: + // The response's Body is closed. + } + }() + resp.Body = ¬ifyingReader{resp.Body, c} + + return resp, nil +} + +// Get issues a GET request via the Do function. +func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Head issues a HEAD request via the Do function. +func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Post issues a POST request via the Do function. +func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return Do(ctx, client, req) +} + +// PostForm issues a POST request via the Do function. +func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { + return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + +// notifyingReader is an io.ReadCloser that closes the notify channel after +// Close is called or a Read fails on the underlying ReadCloser. +type notifyingReader struct { + io.ReadCloser + notify chan<- struct{} +} + +func (r *notifyingReader) Read(p []byte) (int, error) { + n, err := r.ReadCloser.Read(p) + if err != nil && r.notify != nil { + close(r.notify) + r.notify = nil + } + return n, err +} + +func (r *notifyingReader) Close() error { + err := r.ReadCloser.Close() + if r.notify != nil { + close(r.notify) + r.notify = nil + } + return err +} diff --git a/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go new file mode 100644 index 0000000..9159cf0 --- /dev/null +++ b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go @@ -0,0 +1,79 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9,!go1.7 + +package ctxhttp + +import ( + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "golang.org/x/net/context" +) + +// golang.org/issue/14065 +func TestClosesResponseBodyOnCancel(t *testing.T) { + defer func() { testHookContextDoneBeforeHeaders = nop }() + defer func() { testHookDoReturned = nop }() + defer func() { testHookDidBodyClose = nop }() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // closed when Do enters select case <-ctx.Done() + enteredDonePath := make(chan struct{}) + + testHookContextDoneBeforeHeaders = func() { + close(enteredDonePath) + } + + testHookDoReturned = func() { + // We now have the result (the Flush'd headers) at least, + // so we can cancel the request. + cancel() + + // But block the client.Do goroutine from sending + // until Do enters into the <-ctx.Done() path, since + // otherwise if both channels are readable, select + // picks a random one. + <-enteredDonePath + } + + sawBodyClose := make(chan struct{}) + testHookDidBodyClose = func() { close(sawBodyClose) } + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + req, _ := http.NewRequest("GET", ts.URL, nil) + _, doErr := Do(ctx, c, req) + + select { + case <-sawBodyClose: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for body to close") + } + + if doErr != ctx.Err() { + t.Errorf("Do error = %v; want %v", doErr, ctx.Err()) + } +} + +type noteCloseConn struct { + net.Conn + onceClose sync.Once + closefn func() +} + +func (c *noteCloseConn) Close() error { + c.onceClose.Do(c.closefn) + return c.Conn.Close() +} diff --git a/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_test.go b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_test.go new file mode 100644 index 0000000..1e41551 --- /dev/null +++ b/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_test.go @@ -0,0 +1,105 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package ctxhttp + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/net/context" +) + +const ( + requestDuration = 100 * time.Millisecond + requestBody = "ok" +) + +func okHandler(w http.ResponseWriter, r *http.Request) { + time.Sleep(requestDuration) + io.WriteString(w, requestBody) +} + +func TestNoTimeout(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(okHandler)) + defer ts.Close() + + ctx := context.Background() + res, err := Get(ctx, nil, ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != requestBody { + t.Errorf("body = %q; want %q", slurp, requestBody) + } +} + +func TestCancelBeforeHeaders(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + blockServer := make(chan struct{}) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cancel() + <-blockServer + io.WriteString(w, requestBody) + })) + defer ts.Close() + defer close(blockServer) + + res, err := Get(ctx, nil, ts.URL) + if err == nil { + res.Body.Close() + t.Fatal("Get returned unexpected nil error") + } + if err != context.Canceled { + t.Errorf("err = %v; want %v", err, context.Canceled) + } +} + +func TestCancelAfterHangingRequest(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + <-w.(http.CloseNotifier).CloseNotify() + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + resp, err := Get(ctx, nil, ts.URL) + if err != nil { + t.Fatalf("unexpected error in Get: %v", err) + } + + // Cancel befer reading the body. + // Reading Request.Body should fail, since the request was + // canceled before anything was written. + cancel() + + done := make(chan struct{}) + + go func() { + b, err := ioutil.ReadAll(resp.Body) + if len(b) != 0 || err == nil { + t.Errorf(`Read got (%q, %v); want ("", error)`, b, err) + } + close(done) + }() + + select { + case <-time.After(1 * time.Second): + t.Errorf("Test timed out") + case <-done: + } +} diff --git a/vendor/golang.org/x/net/context/go17.go b/vendor/golang.org/x/net/context/go17.go new file mode 100644 index 0000000..d20f52b --- /dev/null +++ b/vendor/golang.org/x/net/context/go17.go @@ -0,0 +1,72 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +package context + +import ( + "context" // standard library's context, as of Go 1.7 + "time" +) + +var ( + todo = context.TODO() + background = context.Background() +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = context.Canceled + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = context.DeadlineExceeded + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + ctx, f := context.WithCancel(parent) + return ctx, CancelFunc(f) +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + ctx, f := context.WithDeadline(parent, deadline) + return ctx, CancelFunc(f) +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return context.WithValue(parent, key, val) +} diff --git a/vendor/golang.org/x/net/context/pre_go17.go b/vendor/golang.org/x/net/context/pre_go17.go new file mode 100644 index 0000000..0f35592 --- /dev/null +++ b/vendor/golang.org/x/net/context/pre_go17.go @@ -0,0 +1,300 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.7 + +package context + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// An emptyCtx is never canceled, has no values, and has no deadline. It is not +// struct{}, since vars of this type must have distinct addresses. +type emptyCtx int + +func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*emptyCtx) Done() <-chan struct{} { + return nil +} + +func (*emptyCtx) Err() error { + return nil +} + +func (*emptyCtx) Value(key interface{}) interface{} { + return nil +} + +func (e *emptyCtx) String() string { + switch e { + case background: + return "context.Background" + case todo: + return "context.TODO" + } + return "unknown empty Context" +} + +var ( + background = new(emptyCtx) + todo = new(emptyCtx) +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = errors.New("context canceled") + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = errors.New("context deadline exceeded") + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := newCancelCtx(parent) + propagateCancel(parent, c) + return c, func() { c.cancel(true, Canceled) } +} + +// newCancelCtx returns an initialized cancelCtx. +func newCancelCtx(parent Context) *cancelCtx { + return &cancelCtx{ + Context: parent, + done: make(chan struct{}), + } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent Context, child canceler) { + if parent.Done() == nil { + return // parent is never canceled + } + if p, ok := parentCancelCtx(parent); ok { + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err) + } else { + if p.children == nil { + p.children = make(map[canceler]bool) + } + p.children[child] = true + } + p.mu.Unlock() + } else { + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err()) + case <-child.Done(): + } + }() + } +} + +// parentCancelCtx follows a chain of parent references until it finds a +// *cancelCtx. This function understands how each of the concrete types in this +// package represents its parent. +func parentCancelCtx(parent Context) (*cancelCtx, bool) { + for { + switch c := parent.(type) { + case *cancelCtx: + return c, true + case *timerCtx: + return c.cancelCtx, true + case *valueCtx: + parent = c.Context + default: + return nil, false + } + } +} + +// removeChild removes a context from its parent. +func removeChild(parent Context, child canceler) { + p, ok := parentCancelCtx(parent) + if !ok { + return + } + p.mu.Lock() + if p.children != nil { + delete(p.children, child) + } + p.mu.Unlock() +} + +// A canceler is a context type that can be canceled directly. The +// implementations are *cancelCtx and *timerCtx. +type canceler interface { + cancel(removeFromParent bool, err error) + Done() <-chan struct{} +} + +// A cancelCtx can be canceled. When canceled, it also cancels any children +// that implement canceler. +type cancelCtx struct { + Context + + done chan struct{} // closed by the first cancel call. + + mu sync.Mutex + children map[canceler]bool // set to nil by the first cancel call + err error // set to non-nil by the first cancel call +} + +func (c *cancelCtx) Done() <-chan struct{} { + return c.done +} + +func (c *cancelCtx) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *cancelCtx) String() string { + return fmt.Sprintf("%v.WithCancel", c.Context) +} + +// cancel closes c.done, cancels each of c's children, and, if +// removeFromParent is true, removes c from its parent's children. +func (c *cancelCtx) cancel(removeFromParent bool, err error) { + if err == nil { + panic("context: internal error: missing cancel error") + } + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return // already canceled + } + c.err = err + close(c.done) + for child := range c.children { + // NOTE: acquiring the child's lock while holding parent's lock. + child.cancel(false, err) + } + c.children = nil + c.mu.Unlock() + + if removeFromParent { + removeChild(c.Context, c) + } +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return WithCancel(parent) + } + c := &timerCtx{ + cancelCtx: newCancelCtx(parent), + deadline: deadline, + } + propagateCancel(parent, c) + d := deadline.Sub(time.Now()) + if d <= 0 { + c.cancel(true, DeadlineExceeded) // deadline has already passed + return c, func() { c.cancel(true, Canceled) } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.timer = time.AfterFunc(d, func() { + c.cancel(true, DeadlineExceeded) + }) + } + return c, func() { c.cancel(true, Canceled) } +} + +// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to +// implement Done and Err. It implements cancel by stopping its timer then +// delegating to cancelCtx.cancel. +type timerCtx struct { + *cancelCtx + timer *time.Timer // Under cancelCtx.mu. + + deadline time.Time +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { + return c.deadline, true +} + +func (c *timerCtx) String() string { + return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) +} + +func (c *timerCtx) cancel(removeFromParent bool, err error) { + c.cancelCtx.cancel(false, err) + if removeFromParent { + // Remove this timerCtx from its parent cancelCtx's children. + removeChild(c.cancelCtx.Context, c) + } + c.mu.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.mu.Unlock() +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return &valueCtx{parent, key, val} +} + +// A valueCtx carries a key-value pair. It implements Value for that key and +// delegates all other calls to the embedded Context. +type valueCtx struct { + Context + key, val interface{} +} + +func (c *valueCtx) String() string { + return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) +} + +func (c *valueCtx) Value(key interface{}) interface{} { + if c.key == key { + return c.val + } + return c.Context.Value(key) +} diff --git a/vendor/golang.org/x/net/context/withtimeout_test.go b/vendor/golang.org/x/net/context/withtimeout_test.go new file mode 100644 index 0000000..a6754dc --- /dev/null +++ b/vendor/golang.org/x/net/context/withtimeout_test.go @@ -0,0 +1,26 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + "fmt" + "time" + + "golang.org/x/net/context" +) + +func ExampleWithTimeout() { + // Pass a context with a timeout to tell a blocking function that it + // should abandon its work after the timeout elapses. + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + select { + case <-time.After(200 * time.Millisecond): + fmt.Println("overslept") + case <-ctx.Done(): + fmt.Println(ctx.Err()) // prints "context deadline exceeded" + } + // Output: + // context deadline exceeded +} From 48e68affb3d40f7395940f71ccfd3d44b79066dc Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Sun, 12 Feb 2017 13:24:19 +0200 Subject: [PATCH 098/104] reapply: support secPart (subsecond) resolution in timestamp and datetime --- .../siddontang/go-mysql/replication/row_event.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vendor/github.com/siddontang/go-mysql/replication/row_event.go b/vendor/github.com/siddontang/go-mysql/replication/row_event.go index b345ae1..0bcb5f6 100644 --- a/vendor/github.com/siddontang/go-mysql/replication/row_event.go +++ b/vendor/github.com/siddontang/go-mysql/replication/row_event.go @@ -600,7 +600,7 @@ func decodeBit(data []byte, nbits int, length int) (value int64, err error) { return } -func decodeTimestamp2(data []byte, dec uint16) (string, int, error) { +func decodeTimestamp2(data []byte, dec uint16) (interface{}, int, error) { //get timestamp binary length n := int(4 + (dec+1)/2) sec := int64(binary.BigEndian.Uint32(data[0:4])) @@ -619,12 +619,12 @@ func decodeTimestamp2(data []byte, dec uint16) (string, int, error) { } t := time.Unix(sec, usec*1000) - return t.Format(TimeFormat), n, nil + return t, n, nil } const DATETIMEF_INT_OFS int64 = 0x8000000000 -func decodeDatetime2(data []byte, dec uint16) (string, int, error) { +func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { //get datetime binary length n := int(5 + (dec+1)/2) @@ -666,9 +666,9 @@ func decodeDatetime2(data []byte, dec uint16) (string, int, error) { hour := int((hms >> 12)) if secPart != 0 { - return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%06d", year, month, day, hour, minute, second, secPart), n, nil + return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%d", year, month, day, hour, minute, second, secPart), n, nil // commented by Shlomi Noach. Yes I know about `git blame` } - return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second), n, nil + return fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second), n, nil // commented by Shlomi Noach. Yes I know about `git blame` } const TIMEF_OFS int64 = 0x800000000000 From 8db8e127ebba537a5d86c5c7224fd7796597702c Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 21 Feb 2017 10:34:24 -0700 Subject: [PATCH 099/104] supporting --timestamp-old-table --- go/base/context.go | 12 +++++++----- go/cmd/gh-ost/main.go | 1 + localtests/test.sh | 1 + 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 314a1d5..d6cd6ce 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -121,6 +121,7 @@ type MigrationContext struct { OkToDropTable bool InitiallyDropOldTable bool InitiallyDropGhostTable bool + TimestampOldTable bool // Should old table name include a timestamp CutOverType CutOver ReplicaServerId uint @@ -234,11 +235,12 @@ func (this *MigrationContext) GetGhostTableName() string { // GetOldTableName generates the name of the "old" table, into which the original table is renamed. func (this *MigrationContext) GetOldTableName() string { - if this.TestOnReplica { - return fmt.Sprintf("_%s_ght", this.OriginalTableName) - } - if this.MigrateOnReplica { - return fmt.Sprintf("_%s_ghr", this.OriginalTableName) + if this.TimestampOldTable { + t := this.StartTime + timestamp := fmt.Sprintf("%d%02d%02d%02d%02d%02d", + t.Year(), t.Month(), t.Day(), + t.Hour(), t.Minute(), t.Second()) + return fmt.Sprintf("_%s_%s_del", this.OriginalTableName, timestamp) } return fmt.Sprintf("_%s_del", this.OriginalTableName) } diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 8e5e20c..238dc81 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -77,6 +77,7 @@ func main() { flag.BoolVar(&migrationContext.OkToDropTable, "ok-to-drop-table", false, "Shall the tool drop the old table at end of operation. DROPping tables can be a long locking operation, which is why I'm not doing it by default. I'm an online tool, yes?") flag.BoolVar(&migrationContext.InitiallyDropOldTable, "initially-drop-old-table", false, "Drop a possibly existing OLD table (remains from a previous run?) before beginning operation. Default is to panic and abort if such table exists") flag.BoolVar(&migrationContext.InitiallyDropGhostTable, "initially-drop-ghost-table", false, "Drop a possibly existing Ghost table (remains from a previous run?) before beginning operation. Default is to panic and abort if such table exists") + flag.BoolVar(&migrationContext.TimestampOldTable, "timestamp-old-table", false, "Use a timestamp in old table name. This makes old table names unique and non conflicting cross migrations") cutOver := flag.String("cut-over", "atomic", "choose cut-over type (default|atomic, two-step)") flag.BoolVar(&migrationContext.ForceNamedCutOverCommand, "force-named-cut-over", false, "When true, the 'unpostpone|cut-over' interactive command must name the migrated table") diff --git a/localtests/test.sh b/localtests/test.sh index 8bbcf6f..12ec77b 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -85,6 +85,7 @@ test_single() { --switch-to-rbr \ --initially-drop-old-table \ --initially-drop-ghost-table \ + --timestamp-old-table \ --throttle-query='select timestampdiff(second, min(last_update), now()) < 5 from _gh_ost_test_ghc' \ --serve-socket-file=/tmp/gh-ost.test.sock \ --initially-drop-socket-file \ From b848b78fcb89ef628286439c1deba0a0aebd5920 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 21 Feb 2017 11:25:29 -0700 Subject: [PATCH 100/104] localtests should not be using --timestamp-old-table --- localtests/test.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/localtests/test.sh b/localtests/test.sh index 12ec77b..8bbcf6f 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -85,7 +85,6 @@ test_single() { --switch-to-rbr \ --initially-drop-old-table \ --initially-drop-ghost-table \ - --timestamp-old-table \ --throttle-query='select timestampdiff(second, min(last_update), now()) < 5 from _gh_ost_test_ghc' \ --serve-socket-file=/tmp/gh-ost.test.sock \ --initially-drop-socket-file \ From 06c909bd100ea50f35232b33b844a400fbc0c307 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 21 Feb 2017 17:34:49 -0700 Subject: [PATCH 101/104] Validating table name length --- go/logic/applier.go | 4 ++++ go/mysql/utils.go | 2 ++ 2 files changed, 6 insertions(+) diff --git a/go/logic/applier.go b/go/logic/applier.go index d1d7970..e84b5cf 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -142,6 +142,10 @@ func (this *Applier) ValidateOrDropExistingTables() error { return err } } + if len(this.migrationContext.GetOldTableName()) > mysql.MaxTableNameLength { + log.Fatalf("--timestamp-old-table defined, but resulting table name (%s) is too long (only %d characters allowed)", this.migrationContext.GetOldTableName(), mysql.MaxTableNameLength) + } + if this.tableExists(this.migrationContext.GetOldTableName()) { return fmt.Errorf("Table %s already exists. Panicking. Use --initially-drop-old-table to force dropping it, though I really prefer that you drop it or rename it away", sql.EscapeName(this.migrationContext.GetOldTableName())) } diff --git a/go/mysql/utils.go b/go/mysql/utils.go index c48e07e..4df7639 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -16,6 +16,8 @@ import ( "github.com/outbrain/golib/sqlutils" ) +const MaxTableNameLength = 64 + type ReplicationLagResult struct { Key InstanceKey Lag time.Duration From d15594824ad051aa80fe3b5ddbdd74539ac047a5 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 21 Feb 2017 17:34:58 -0700 Subject: [PATCH 102/104] documenting --timestamp-old-table --- doc/command-line-flags.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index 83bfe05..d266f2c 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -130,3 +130,7 @@ See `approve-renamed-columns` ### test-on-replica Issue the migration on a replica; do not modify data on master. Useful for validating, testing and benchmarking. See [testing-on-replica](testing-on-replica.md) + +### timestamp-old-table + +Makes the _old_ table include a timestamp value. The _old_ table is what the original table is renamed to at the end of a successful migration. For example, if the table is `gh_ost_test`, then the _old_ table would normally be `_gh_ost_test_del`. With `--timestamp-old-table` it would be, for example, `_gh_ost_test_20170221103147_del`. From f11f6f978ffe3a1de46e4e6302c5834340b9a496 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Fri, 24 Feb 2017 14:48:10 -0700 Subject: [PATCH 103/104] mitigating cut-over/write race condition --- go/logic/inspect.go | 4 ++-- go/logic/migrator.go | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 862be16..1d30cb5 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -348,7 +348,7 @@ func (this *Inspector) validateLogSlaveUpdates() error { } if this.migrationContext.IsTungsten { - log.Warning("log_slave_updates not found on %s:%d, but --tungsten provided, so I'm proceeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + log.Warningf("log_slave_updates not found on %s:%d, but --tungsten provided, so I'm proceeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) return nil } @@ -357,7 +357,7 @@ func (this *Inspector) validateLogSlaveUpdates() error { } if this.migrationContext.InspectorIsAlsoApplier() { - log.Warning("log_slave_updates not found on %s:%d, but executing directly on master, so I'm proceeeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + log.Warningf("log_slave_updates not found on %s:%d, but executing directly on master, so I'm proceeeding", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) return nil } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index de3a018..94b5e54 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -1000,11 +1000,13 @@ func (this *Migrator) iterateChunks() error { for { if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 { // Done + // There's another such check down the line return nil } copyRowsFunc := func() error { if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 { - // Done + // Done. + // There's another such check down the line return nil } hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues() @@ -1016,6 +1018,17 @@ func (this *Migrator) iterateChunks() error { } // Copy task: applyCopyRowsFunc := func() error { + if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 { + // No need for more writes. + // This is the de-facto place where we avoid writing in the event of completed cut-over. + // There could _still_ be a race condition, but that's as close as we can get. + // What about the race condition? Well, there's actually no data integrity issue. + // when rowCopyCompleteFlag==1 that means **guaranteed** all necessary rows have been copied. + // But some are still then collected at the binary log, and these are the ones we're trying to + // not apply here. If the race condition wins over us, then we just attempt to apply onto the + // _ghost_ table, which no longer exists. So, bothering error messages and all, but no damage. + return nil + } _, rowsAffected, _, err := this.applier.ApplyIterationInsertQuery() if err != nil { return terminateRowIteration(err) From 7f728729e3352af80f0338a46ad9be376bf27035 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Tue, 7 Mar 2017 08:18:38 +0200 Subject: [PATCH 104/104] documenting gh-ost-on-start-replication hook --- doc/hooks.md | 1 + resources/hooks-sample/gh-ost-on-start-replication-hook | 6 ++++++ resources/hooks-sample/gh-ost-on-stop-replication-hook | 1 + 3 files changed, 8 insertions(+) create mode 100644 resources/hooks-sample/gh-ost-on-start-replication-hook diff --git a/doc/hooks.md b/doc/hooks.md index 03b1501..43ddb7e 100644 --- a/doc/hooks.md +++ b/doc/hooks.md @@ -44,6 +44,7 @@ The full list of supported hooks is best found in code: [hooks.go](https://githu - `gh-ost-on-interactive-command` - `gh-ost-on-row-copy-complete` - `gh-ost-on-stop-replication` +- `gh-ost-on-start-replication` - `gh-ost-on-begin-postponed` - `gh-ost-on-before-cut-over` - `gh-ost-on-success` diff --git a/resources/hooks-sample/gh-ost-on-start-replication-hook b/resources/hooks-sample/gh-ost-on-start-replication-hook new file mode 100644 index 0000000..b476b99 --- /dev/null +++ b/resources/hooks-sample/gh-ost-on-start-replication-hook @@ -0,0 +1,6 @@ +#!/bin/bash + +# Sample hook file for gh-ost-on-start-replication +# Useful for RDS/Aurora setups, see https://github.com/github/gh-ost/issues/163 + +echo "$(date) gh-ost-on-start-replication $GH_OST_DATABASE_NAME.$GH_OST_TABLE_NAME $GH_OST_MIGRATED_HOST" >> /tmp/gh-ost.log diff --git a/resources/hooks-sample/gh-ost-on-stop-replication-hook b/resources/hooks-sample/gh-ost-on-stop-replication-hook index 232ad95..629e4dc 100755 --- a/resources/hooks-sample/gh-ost-on-stop-replication-hook +++ b/resources/hooks-sample/gh-ost-on-stop-replication-hook @@ -1,5 +1,6 @@ #!/bin/bash # Sample hook file for gh-ost-on-stop-replication +# Useful for RDS/Aurora setups, see https://github.com/github/gh-ost/issues/163 echo "$(date) gh-ost-on-stop-replication $GH_OST_DATABASE_NAME.$GH_OST_TABLE_NAME $GH_OST_MIGRATED_HOST" >> /tmp/gh-ost.log